diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000000..2f7896d1d136 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 000000000000..08df965ae6bb --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,385 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/pom.xml b/examples/pom.xml new file mode 100644 index 000000000000..fcb52fcdbf8d --- /dev/null +++ b/examples/pom.xml @@ -0,0 +1,223 @@ + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + manual_build + + + com.google.cloud.dataflow + google-cloud-dataflow-java-examples-all + Google Cloud Dataflow Java Examples - All + Google Cloud Dataflow Java SDK provides a simple, Java-based + interface for processing virtually any size data using Google cloud + resources. This artifact includes all Dataflow Java SDK + examples. + http://cloud.google.com/dataflow + + manual_build + + jar + + + + DataflowPipelineTests + + true + com.google.cloud.dataflow.sdk.testing.RunnableOnService + both + + + + + + + + maven-compiler-plugin + + + + org.apache.maven.plugins + maven-dependency-plugin + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.12 + + ../checkstyle.xml + true + true + true + + + + + check + + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.4 + + + attach-sources + compile + + jar + + + + attach-test-sources + test-compile + + test-jar + + + + + + + org.apache.felix + maven-bundle-plugin + 2.4.0 + true + + ${project.artifactId}-bundled-${project.version} + + + *;scope=compile|runtime;artifactId=!google-cloud-dataflow-java-sdk-all;inline=true + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + dataflow-examples-compile + compile + + jar + + + + dataflow-examples-test-compile + test-compile + + test-jar + + + + + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + ${project.version} + + + + com.google.apis + google-api-services-storage + v1-rev11-1.19.0 + + + + com.google.apis + google-api-services-bigquery + v2-rev167-1.19.0 + + + + com.google.guava + guava-jdk5 + + + + + + com.google.http-client + google-http-client-jackson2 + 1.19.0 + + + + com.fasterxml.jackson.core + jackson-core + 2.4.2 + + + + com.fasterxml.jackson.core + jackson-annotations + 2.4.2 + + + + + org.slf4j + slf4j-api + 1.7.7 + + + + org.slf4j + slf4j-jdk14 + 1.7.7 + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + ${project.version} + test-jar + test + + + + org.hamcrest + hamcrest-all + 1.3 + test + + + + junit + junit + 4.11 + test + + + diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/BigQueryTornadoes.java b/examples/src/main/java/com/google/cloud/dataflow/examples/BigQueryTornadoes.java new file mode 100644 index 000000000000..43e94c08633b --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/BigQueryTornadoes.java @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.util.ArrayList; +import java.util.List; + +/** + * An example that reads the public samples of weather data from BigQuery, counts the number of + * tornadoes that occur in each month, and writes the results to BigQuery. + */ +public class BigQueryTornadoes { + // Default to using a 1000 row subset of the public weather station table publicdata:samples.gsod. + private static final String WEATHER_SAMPLES_TABLE = + "clouddataflow-readonly:samples.weather_stations"; + + /** + * Examines each row in the input table. If a tornado was recorded in that sample, the month in + * which it occurred is output. + */ + static class ExtractTornadoesFn extends DoFn { + @Override + public void processElement(ProcessContext c){ + TableRow row = c.element(); + if ((Boolean) row.get("tornado")) { + c.output(Integer.parseInt((String) row.get("month"))); + } + } + } + + /** + * Prepares the data for writing to BigQuery by building a TableRow object containing an + * integer representation of month and the number of tornadoes that occurred in each month. + */ + static class FormatCountsFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + TableRow row = new TableRow() + .set("month", c.element().getKey().intValue()) + .set("tornado_count", c.element().getValue().longValue()); + c.output(row); + } + } + + /** + * Takes rows from a table and generates a table of counts. + * + * The input schema is described by + * https://developers.google.com/bigquery/docs/dataset-gsod . + * The output contains the total number of tornadoes found in each month in + * the following schema: + *
    + *
  • month: integer
  • + *
  • tornado_count: integer
  • + *
+ */ + static class CountTornadoes + extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection rows) { + + // row... => month... + PCollection tornadoes = rows.apply( + ParDo.of(new ExtractTornadoesFn())); + + // month... => ... + PCollection> tornadoCounts = + tornadoes.apply(Count.perElement()); + + // ... => row... + PCollection results = tornadoCounts.apply( + ParDo.of(new FormatCountsFn())); + + return results; + } + } + + /** + * Options supported by {@link BigQueryTornadoes}. + *

+ * Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Table to read from, specified as " + + ":.") + @Default.String(WEATHER_SAMPLES_TABLE) + String getInput(); + void setInput(String value); + + @Description("Table to write to, specified as " + + ":.") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + + Pipeline p = Pipeline.create(options); + + // Build the table schema for the output table. + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("month").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("tornado_count").setType("INTEGER")); + TableSchema schema = new TableSchema().setFields(fields); + + p.apply(BigQueryIO.Read.from(options.getInput())) + .apply(new CountTornadoes()) + .apply(BigQueryIO.Write + .to(options.getOutput()) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/DatastoreWordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/DatastoreWordCount.java new file mode 100644 index 000000000000..1e00589281aa --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/DatastoreWordCount.java @@ -0,0 +1,198 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.Key; +import com.google.api.services.datastore.DatastoreV1.Property; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.api.services.datastore.DatastoreV1.Value; +import com.google.api.services.datastore.client.DatastoreHelper; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.DatastoreIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +import java.util.Map; + +/** + * A WordCount example using DatastoreIO. + * + *

This example shows how to use DatastoreIO to read from Datastore and + * write the results to Cloud Storage. Note that this example will write + * data to Datastore, which may incur charge for Datastore operations. + * + *

To run this example, users need to set up the environment and use gcloud + * to get credential for Datastore: + *

+ * $ export CLOUDSDK_EXTRA_SCOPES=https://www.googleapis.com/auth/datastore
+ * $ gcloud auth login
+ * 
+ * + *

Note that the environment variable CLOUDSDK_EXTRA_SCOPES must be set + * to the same value when executing a Datastore pipeline, as the local auth + * cache is keyed by the requested scopes. + * + *

To run this pipeline locally, the following options must be provided: + *

{@code
+ *   --project=
+ *   --dataset=
+ *   --output=[ | gs://]
+ * }
+ * + *

To run this example using Dataflow service, you must additionally + * provide either {@literal --stagingLocation} or {@literal --tempLocation}, and + * select one of the Dataflow pipeline runners, eg + * {@literal --runner=BlockingDataflowPipelineRunner}. + */ +public class DatastoreWordCount { + + /** + * A DoFn that gets the content of an entity (one line in a + * Shakespeare play) and converts it to a string. + */ + static class GetContentFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Map props = DatastoreHelper.getPropertyMap(c.element()); + c.output(DatastoreHelper.getString(props.get("content"))); + } + } + + /** + * A DoFn that creates entity for every line in Shakespeare. + */ + static class CreateEntityFn extends DoFn { + private String kind; + + CreateEntityFn(String kind) { + this.kind = kind; + } + + public Entity makeEntity(String content) { + Entity.Builder entityBuilder = Entity.newBuilder(); + // Create entities with same ancestor Key. + Key ancestorKey = DatastoreHelper.makeKey(kind, "root").build(); + Key key = DatastoreHelper.makeKey(ancestorKey, kind).build(); + entityBuilder.setKey(key); + entityBuilder.addProperty(Property.newBuilder() + .setName("content") + .setValue(Value.newBuilder().setStringValue(content))); + return entityBuilder.build(); + } + + @Override + public void processElement(ProcessContext c) { + c.output(makeEntity(c.element())); + } + } + + /** + * Options supported by {@link DatastoreWordCount}. + *

+ * Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Path of the file to read from and store to Datastore") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInput(); + void setInput(String value); + + @Description("Path of the file to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + + @Description("Dataset ID to read from datastore") + @Validation.Required + String getDataset(); + void setDataset(String value); + + @Description("Dataset entity kind") + @Default.String("shakespeare-demo") + String getKind(); + void setKind(String value); + + @Description("Read an existing dataset, do not write first") + boolean isReadOnly(); + void setReadOnly(boolean value); + } + + /** + * An example which creates a pipeline to populate DatastoreIO from a + * text input. Forces use of DirectPipelineRunner for local execution mode. + */ + public static void writeDataToDatastore(Options options) { + // Runs locally via DirectPiplineRunner, as writing is not yet implemented + // for the other runners which is why we just create a PipelineOptions with defaults. + Pipeline p = Pipeline.create(PipelineOptionsFactory.create()); + p.apply(TextIO.Read.named("ReadLines").from(options.getInput())) + .apply(ParDo.of(new CreateEntityFn(options.getKind()))) + .apply(DatastoreIO.Write.to(options.getDataset())); + + p.run(); + } + + /** + * An example which creates a pipeline to do DatastoreIO.Read from Datastore. + */ + public static void readDataFromDatastore(Options options) { + // Build a query: read all entities of the specified kind. + Query.Builder q = Query.newBuilder(); + q.addKindBuilder().setName(options.getKind()); + Query query = q.build(); + + Pipeline p = Pipeline.create(options); + p.apply(DatastoreIO.Read.named("ReadShakespeareFromDatastore") + .from(options.getDataset(), query)) + .apply(ParDo.of(new GetContentFn())) + .apply(new WordCount.CountWords()) + .apply(TextIO.Write.named("WriteLines").to(options.getOutput())); + + p.run(); + } + + /** + * Main function. + * An example to demo how to use DatastoreIO. The runner here is + * customizable, which means users could pass either DirectPipelineRunner + * or DataflowPipelineRunner in PipelineOptions. + */ + public static void main(String args[]) { + // The options are used in two places, for Dataflow service, and + // building DatastoreIO.Read object + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + + if (!options.isReadOnly()) { + // First example: write data to Datastore for reading later. + // Note: this will insert new entries with the given kind. Existing entries + // should be cleared first, or the final counts will contain duplicates. + // The Datastore Admin tool in the AppEngine console can be used to erase + // all entries with a particular kind. + DatastoreWordCount.writeDataToDatastore(options); + } + + // Second example: do parallel read from Datastore. + DatastoreWordCount.readDataFromDatastore(options); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/TfIdf.java b/examples/src/main/java/com/google/cloud/dataflow/examples/TfIdf.java new file mode 100644 index 000000000000..a6bd4f27fd61 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/TfIdf.java @@ -0,0 +1,425 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.URICoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.transforms.Values; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashSet; +import java.util.Set; + +/** + * An example that computes a basic TF-IDF search table for a directory or GCS prefix. + * + *

Command-line usage for this example: + * + *

+ *     java com.google.cloud.dataflow.examples.TfIdf \
+ *       --runner= \
+ *       --input= \
+ *       --output=
+ * 
+ * + *

For example, to execute this pipeline locally to index a local directory: + * + *

+ *     java com.google.cloud.dataflow.examples.TfIdf \
+ *       --runner=DirectPipelineRunner \
+ *       --input= \
+ *       --output=
+ * 
+ * + *

To execute this pipeline using the Dataflow service + * to index the works of Shakespeare and write the results to a GCS bucket: + * (For execution via the Dataflow service, only GCS locations are supported) + * + *

+ *     java com.google.cloud.dataflow.examples.TfIdf \
+ *       --project= \
+ *       --stagingLocation=gs:// \
+ *       --runner=BlockingDataflowPipelineRunner \
+ *       [--input=gs://] \
+ *       --output=gs://
+ * 
+ * + *

The default input is gs://dataflow-samples/shakespeare/ + */ +public class TfIdf { + /** + * Options supported by {@link TfIdf}. + *

+ * Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Path to the directory or GCS prefix containing files to read from") + @Default.String("gs://dataflow-samples/shakespeare/") + String getInput(); + void setInput(String value); + + @Description("Prefix of output URI to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + /** + * Lists documents contained beneath the {@code options.input} prefix/directory. + */ + public static Set listInputDocuments(Options options) + throws URISyntaxException, IOException { + URI baseUri = new URI(options.getInput()); + + // List all documents in the directory or GCS prefix. + URI absoluteUri; + if (baseUri.getScheme() != null) { + absoluteUri = baseUri; + } else { + absoluteUri = new URI( + "file", + baseUri.getAuthority(), + baseUri.getPath(), + baseUri.getQuery(), + baseUri.getFragment()); + } + + Set uris = new HashSet<>(); + if (absoluteUri.getScheme().equals("file")) { + File directory = new File(absoluteUri); + for (String entry : directory.list()) { + File path = new File(directory, entry); + uris.add(path.toURI()); + } + } else if (absoluteUri.getScheme().equals("gs")) { + GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); + URI gcsUriGlob = new URI( + absoluteUri.getScheme(), + absoluteUri.getAuthority(), + absoluteUri.getPath() + "*", + absoluteUri.getQuery(), + absoluteUri.getFragment()); + for (GcsPath entry : gcsUtil.expand(GcsPath.fromUri(gcsUriGlob))) { + uris.add(entry.toUri()); + } + } + + return uris; + } + + /** + * Reads the documents at the provided uris and returns all lines + * from the documents tagged with which document they are from. + */ + public static class ReadDocuments + extends PTransform>> { + + private Iterable uris; + + public ReadDocuments(Iterable uris) { + this.uris = uris; + } + + @Override + public Coder getDefaultOutputCoder() { + return KvCoder.of(URICoder.of(), StringUtf8Coder.of()); + } + + @Override + public PCollection> apply(PInput input) { + Pipeline pipeline = getPipeline(); + + // Create one TextIO.Read transform for each document + // and add its output to a PCollectionList + PCollectionList> urisToLines = + PCollectionList.empty(pipeline); + + // TextIO.Read supports: + // - file: URIs and paths locally + // - gs: URIs on the service + for (final URI uri : uris) { + String uriString; + if (uri.getScheme().equals("file")) { + uriString = new File(uri).getPath(); + } else { + uriString = uri.toString(); + } + + PCollection> oneUriToLines = pipeline + .apply(TextIO.Read.from(uriString) + .named("TextIO.Read(" + uriString + ")")) + .apply(WithKeys.of(uri)); + + urisToLines = urisToLines.and(oneUriToLines); + } + + return urisToLines.apply(Flatten.>create()); + } + } + + /** + * A transform containing a basic TF-IDF pipeline. The input consists of KV objects + * where the key is the document's URI and the value is a piece + * of the document's content. The output is mapping from terms to + * scores for each document URI. + */ + public static class ComputeTfIdf + extends PTransform>, PCollection>>> { + + public ComputeTfIdf() { } + + @Override + public PCollection>> apply( + PCollection> uriToContent) { + + // Compute the total number of documents, and + // prepare this singleton PCollectionView for + // use as a side input. + final PCollectionView totalDocuments = + uriToContent + .apply(Keys.create()) + .apply(RemoveDuplicates.create()) + .apply(Count.globally()) + .apply(View.asSingleton()); + + // Create a collection of pairs mapping a URI to each + // of the words in the document associated with that that URI. + PCollection> uriToWords = uriToContent + .apply(ParDo.named("SplitWords").of( + new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey(); + String line = c.element().getValue(); + for (String word : line.split("\\W+")) { + if (!word.isEmpty()) { + c.output(KV.of(uri, word.toLowerCase())); + } + } + } + })); + + // Compute a mapping from each word to the total + // number of documents in which it appears. + PCollection> wordToDocCount = uriToWords + .apply(RemoveDuplicates.>create()) + .apply(Values.create()) + .apply(Count.perElement()); + + // Compute a mapping from each URI to the total + // number of words in the document associated with that URI. + PCollection> uriToWordTotal = uriToWords + .apply(Keys.create()) + .apply(Count.perElement()); + + // Count, for each (URI, word) pair, the number of + // occurrences of that word in the document associated + // with the URI. + PCollection, Long>> uriAndWordToCount = uriToWords + .apply(Count.>perElement()); + + // Adjust the above collection to a mapping from + // (URI, word) pairs to counts into an isomorphic mapping + // from URI to (word, count) pairs, to prepare for a join + // by the URI key. + PCollection>> uriToWordAndCount = uriAndWordToCount + .apply(ParDo.of(new DoFn, Long>, KV>>() { + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey().getKey(); + String word = c.element().getKey().getValue(); + Long occurrences = c.element().getValue(); + c.output(KV.of(uri, KV.of(word, occurrences))); + } + })); + + // Prepare to join the mapping of URI to (word, count) pairs with + // the mapping of URI to total word counts, by associating + // each of the input PCollection> with + // a tuple tag. Each input must have the same key type, URI + // in this case. The type parameter of the tuple tag matches + // the types of the values for each collection. + final TupleTag wordTotalsTag = new TupleTag(); + final TupleTag> wordCountsTag = new TupleTag>(); + KeyedPCollectionTuple coGbkInput = KeyedPCollectionTuple + .of(wordTotalsTag, uriToWordTotal) + .and(wordCountsTag, uriToWordAndCount); + + // Perform a CoGroupByKey (a sort of pre-join) on the prepared + // inputs. This yields a mapping from URI to a CoGbkResult + // (CoGroupByKey Result). The CoGbkResult is a mapping + // from the above tuple tags to the values in each input + // associated with a particular URI. In this case, each + // KV group a URI with the total number of + // words in that document as well as all the (word, count) + // pairs for particular words. + PCollection> uriToWordAndCountAndTotal = coGbkInput + .apply(CoGroupByKey.create().withName("CoGroupByURI")); + + // Compute a mapping from each word to a (URI, term frequency) + // pair for each URI. A word's term frequency for a document + // is simply the number of times that word occurs in the document + // divided by the total number of words in the document. + PCollection>> wordToUriAndTf = uriToWordAndCountAndTotal + .apply(ParDo.of(new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey(); + Long wordTotal = c.element().getValue().getOnly(wordTotalsTag); + + for (KV wordAndCount : c.element().getValue().getAll(wordCountsTag)) { + String word = wordAndCount.getKey(); + Long wordCount = wordAndCount.getValue(); + Double termFrequency = wordCount.doubleValue() / wordTotal.doubleValue(); + c.output(KV.of(word, KV.of(uri, termFrequency))); + } + } + })); + + // Compute a mapping from each word to its document frequency. + // A word's document frequency in a corpus is the number of + // documents in which the word appears divided by the total + // number of documents in the corpus. Note how the total number of + // documents is passed as a side input; the same value is + // presented to each invocation of the DoFn. + PCollection> wordToDf = wordToDocCount + .apply(ParDo + .withSideInputs(totalDocuments) + .of(new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + String word = c.element().getKey(); + Long documentCount = c.element().getValue(); + Long documentTotal = c.sideInput(totalDocuments); + Double documentFrequency = documentCount.doubleValue() + / documentTotal.doubleValue(); + + c.output(KV.of(word, documentFrequency)); + } + })); + + // Join the term frequency and document frequency + // collections, each keyed on the word. + final TupleTag> tfTag = new TupleTag>(); + final TupleTag dfTag = new TupleTag(); + PCollection> wordToUriAndTfAndDf = KeyedPCollectionTuple + .of(tfTag, wordToUriAndTf) + .and(dfTag, wordToDf) + .apply(CoGroupByKey.create()); + + // Compute a mapping from each word to a (URI, TF-IDF) score + // for each URI. There are a variety of definitions of TF-IDF + // ("term frequency - inverse document frequency") score; + // here we use a basic version which is the term frequency + // divided by the log of the document frequency. + PCollection>> wordToUriAndTfIdf = wordToUriAndTfAndDf + .apply(ParDo.of(new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + String word = c.element().getKey(); + Double df = c.element().getValue().getOnly(dfTag); + + for (KV uriAndTf : c.element().getValue().getAll(tfTag)) { + URI uri = uriAndTf.getKey(); + Double tf = uriAndTf.getValue(); + Double tfIdf = tf * Math.log(1 / df); + c.output(KV.of(word, KV.of(uri, tfIdf))); + } + } + })); + + return wordToUriAndTfIdf; + } + } + + /** + * A {@link PTransform} to write, in CSV format, a mapping from term and URI + * to score. + */ + public static class WriteTfIdf + extends PTransform>>, PDone> { + + private String output; + + public WriteTfIdf(String output) { + this.output = output; + } + + @Override + public PDone apply(PCollection>> wordToUriAndTfIdf) { + return wordToUriAndTfIdf + .apply(ParDo.of(new DoFn>, String>() { + @Override + public void processElement(ProcessContext c) { + c.output(String.format("%s,\t%s,\t%f", + c.element().getKey(), + c.element().getValue().getKey(), + c.element().getValue().getValue())); + } + })) + .apply(TextIO.Write + .to(output) + .withSuffix(".csv")); + } + } + + public static void main(String[] args) throws Exception { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline pipeline = Pipeline.create(options); + + pipeline + .apply(new ReadDocuments(listInputDocuments(options))) + .apply(new ComputeTfIdf()) + .apply(new WriteTfIdf(options.getOutput())); + + pipeline.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/TopWikipediaSessions.java b/examples/src/main/java/com/google/cloud/dataflow/examples/TopWikipediaSessions.java new file mode 100644 index 000000000000..baa520ea0447 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/TopWikipediaSessions.java @@ -0,0 +1,208 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableComparator; +import com.google.cloud.dataflow.sdk.transforms.Top; +import com.google.cloud.dataflow.sdk.transforms.windowing.CalendarWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.List; + +/** + * Pipeline that reads Wikipedia edit data from BigQuery and computes the user with + * the longest string of edits separated by no more than an hour within each month. + * + *

This pipeline demonstrates how the Windowing API can be used to perform + * various time-based aggregations of data. + * + *

To run this pipeline, the following options must be provided: + *

{@code
+ *   --project=
+ *   --output=gs://
+ *   --stagingLocation=gs://
+ *   --runner=(Blocking)DataflowPipelineRunner
+ * }
+ * + *

To run this example using Dataflow service, you must additionally + * provide either {@literal --stagingLocation} or {@literal --tempLocation}, and + * select one of the Dataflow pipeline runners, eg + * {@literal --runner=BlockingDataflowPipelineRunner}. + */ +public class TopWikipediaSessions { + private static final String EXPORTED_WIKI_TABLE = "gs://dataflow-samples/wikipedia_edits/*.json"; + + /** + * Extracts user and timestamp from a TableRow representing a Wikipedia edit + */ + static class ExtractUserAndTimestamp extends DoFn { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + int timestamp = (Integer) row.get("timestamp"); + String userName = (String) row.get("contributor_username"); + if (userName != null) { + // Sets the implicit timestamp field to be used in windowing. + c.outputWithTimestamp(userName, new Instant(timestamp * 1000L)); + } + } + } + + /** + * Computes the number of edits in each user session. A session is defined as + * a string of edits where each is separated from the next by less than an hour. + */ + static class ComputeSessions + extends PTransform, PCollection>> { + @Override + public PCollection> apply(PCollection actions) { + return actions + .apply(Window.into(Sessions.withGapDuration(Duration.standardHours(1)))) + + .apply(Count.perElement()); + } + } + + /** + * Computes the longest session ending in each month. + */ + private static class TopPerMonth + extends PTransform>, PCollection>>> { + @Override + public PCollection>> apply(PCollection> sessions) { + return sessions + .apply(Window.>into(CalendarWindows.months(1))) + + .apply(Top.of(1, new SerializableComparator>() { + @Override + public int compare(KV o1, KV o2) { + return Long.compare(o1.getValue(), o2.getValue()); + } + })); + } + } + + static class ComputeTopSessions extends PTransform, PCollection> { + private final double samplingThreshold; + + public ComputeTopSessions(double samplingThreshold) { + this.samplingThreshold = samplingThreshold; + } + + @Override + public PCollection apply(PCollection input) { + return input + .apply(ParDo.of(new ExtractUserAndTimestamp())) + + .apply(ParDo.named("SampleUsers").of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (Math.abs(c.element().hashCode()) <= Integer.MAX_VALUE * samplingThreshold) { + c.output(c.element()); + } + } + })) + + .apply(new ComputeSessions()) + + .apply(ParDo.named("SessionsToStrings").of( + new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of( + c.element().getKey() + " : " + + c.windows().iterator().next(), c.element().getValue())); + } + })) + + .apply(new TopPerMonth()) + + .apply(ParDo.named("FormatOutput").of( + new DoFn>, String>() { + @Override + public void processElement(ProcessContext c) { + for (KV item : c.element()) { + String session = item.getKey(); + long count = item.getValue(); + c.output( + session + " : " + count + " : " + + ((IntervalWindow) c.windows().iterator().next()).start()); + } + } + })); + } + } + + /** + * Options supported by this class. + * + *

Inherits standard Dataflow configuration options. + */ + private static interface Options extends PipelineOptions { + @Description( + "Input specified as a GCS path containing a BigQuery table exported as json") + @Default.String(EXPORTED_WIKI_TABLE) + String getInput(); + void setInput(String value); + + @Description("File to output results to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(Options.class); + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + + Pipeline p = Pipeline.create(dataflowOptions); + + double samplingThreshold = 0.1; + + p.apply(TextIO.Read + .from(options.getInput()) + .withCoder(TableRowJsonCoder.of())) + .apply(new ComputeTopSessions(samplingThreshold)) + .apply(TextIO.Write.named("Write").withoutSharding().to(options.getOutput())); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java new file mode 100644 index 000000000000..96893b909bc7 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java @@ -0,0 +1,174 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * An example that counts words in Shakespeare. For a detailed walkthrough of this + * example see: + * https://developers.google.com/cloud-dataflow/java-sdk/wordcount-example + * + * To execute this pipeline locally, specify general pipeline configuration: + * --project= + * and example configuration: + * --output=[ | gs://] + * + * To execute this pipeline using the Dataflow service, specify pipeline configuration: + * --project= --stagingLocation=gs:// + * --runner=BlockingDataflowPipelineRunner + * and example configuration: + * --output=gs:// + * + * The input file defaults to gs://dataflow-samples/shakespeare/kinglear.txt and can be + * overridden with --input. + */ +public class WordCount { + + /** A DoFn that tokenizes lines of text into individual words. */ + static class ExtractWordsFn extends DoFn { + private Aggregator emptyLines; + + @Override + public void startBundle(Context c) { + emptyLines = c.createAggregator("emptyLines", new Sum.SumLongFn()); + } + + @Override + public void processElement(ProcessContext c) { + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Keep track of the number of lines without any words encountered while tokenizing. + // This aggregator is visible in the monitoring UI when run using DataflowPipelineRunner. + if (words.length == 0) { + emptyLines.addValue(1L); + } + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + /** A DoFn that converts a Word and Count into a printable string. */ + static class FormatCountsFn extends DoFn, String> { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey() + ": " + c.element().getValue()); + } + } + + /** + * A PTransform that converts a PCollection containing lines of text into a PCollection of + * formatted word counts. + *

+ * Although this pipeline fragment could be inlined, bundling it as a PTransform allows for easy + * reuse, modular testing, and an improved monitoring experience. + */ + public static class CountWords extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection lines) { + + // Convert lines of text into individual words. + PCollection words = lines.apply( + ParDo.of(new ExtractWordsFn())); + + // Count the number of times each word occurs. + PCollection> wordCounts = + words.apply(Count.perElement()); + + // Format each word and count into a printable string. + PCollection results = wordCounts.apply( + ParDo.of(new FormatCountsFn())); + + return results; + } + } + + /** + * Options supported by {@link WordCount}. + *

+ * Inherits standard configuration options. + */ + public static interface Options extends PipelineOptions { + @Description("Path of the file to read from") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInput(); + void setInput(String value); + + @Description("Path of the file to write to") + @Default.InstanceFactory(OutputFactory.class) + String getOutput(); + void setOutput(String value); + + /** Returns gs://${STAGING_LOCATION}/"counts.txt" */ + public static class OutputFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + if (dataflowOptions.getStagingLocation() != null) { + return GcsPath.fromUri(dataflowOptions.getStagingLocation()) + .resolve("counts.txt").toString(); + } else { + throw new IllegalArgumentException("Must specify --output or --stagingLocation"); + } + } + } + + /** + * By default (numShards == 0), the system will choose the shard count. + * Most programs will not need this option. + */ + @Description("Number of output shards (0 if the system should choose automatically)") + int getNumShards(); + void setNumShards(int value); + } + + public static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + + p.apply(TextIO.Read.named("ReadLines").from(options.getInput())) + .apply(new CountWords()) + .apply(TextIO.Write.named("WriteCounts") + .to(options.getOutput()) + .withNumShards(options.getNumShards())); + + p.run(); + } +} + diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/BigQueryTornadoesTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/BigQueryTornadoesTest.java new file mode 100644 index 000000000000..6dafef703648 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/BigQueryTornadoesTest.java @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.BigQueryTornadoes.ExtractTornadoesFn; +import com.google.cloud.dataflow.examples.BigQueryTornadoes.FormatCountsFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** + * Test case for {@link BigQueryTornadoes}. + */ +@RunWith(JUnit4.class) +public class BigQueryTornadoesTest { + + @Test + public void testExtractTornadoes() throws Exception { + TableRow row = new TableRow() + .set("month", "6") + .set("tornado", true); + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractTornadoesFn()); + Assert.assertThat(extractWordsFn.processBatch(row), + CoreMatchers.hasItems(6)); + } + + @Test + public void testNoTornadoes() throws Exception { + TableRow row = new TableRow() + .set("month", 6) + .set("tornado", false); + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractTornadoesFn()); + Assert.assertTrue(extractWordsFn.processBatch(row).isEmpty()); + } + + @Test + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testFormatCounts() throws Exception { + DoFnTester, TableRow> formatCountsFn = + DoFnTester.of(new FormatCountsFn()); + KV empty[] = {}; + List results = formatCountsFn.processBatch(empty); + Assert.assertTrue(results.size() == 0); + KV input[] = { KV.of(3, 0L), + KV.of(4, Long.MAX_VALUE), + KV.of(5, Long.MIN_VALUE) }; + results = formatCountsFn.processBatch(input); + Assert.assertEquals(results.size(), 3); + Assert.assertEquals(results.get(0).get("month"), 3); + Assert.assertEquals(results.get(0).get("tornado_count"), 0L); + Assert.assertEquals(results.get(1).get("month"), 4); + Assert.assertEquals(results.get(1).get("tornado_count"), Long.MAX_VALUE); + Assert.assertEquals(results.get(2).get("month"), 5); + Assert.assertEquals(results.get(2).get("tornado_count"), Long.MIN_VALUE); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/TfIdfTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/TfIdfTest.java new file mode 100644 index 000000000000..341fd80c25b2 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/TfIdfTest.java @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.net.URI; +import java.util.Arrays; + +/** + * Tests of TfIdf + */ +@RunWith(JUnit4.class) +public class TfIdfTest { + + /** Test that the example runs */ + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testTfIdf() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection>> wordToUriAndTfIdf = pipeline + .apply(Create.of( + KV.of(new URI("x"), "a b c d"), + KV.of(new URI("y"), "a b c"), + KV.of(new URI("z"), "a m n"))) + .apply(new TfIdf.ComputeTfIdf()); + + PCollection words = wordToUriAndTfIdf + .apply(Keys.create()) + .apply(RemoveDuplicates.create()); + + DataflowAssert.that(words).containsInAnyOrder(Arrays.asList("a", "m", "n", "b", "c", "d")); + + pipeline.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/TopWikipediaSessionsTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/TopWikipediaSessionsTest.java new file mode 100644 index 000000000000..ce43ae9930a4 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/TopWikipediaSessionsTest.java @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** Unit tests for {@link TopWikipediaSessions}. */ +@RunWith(JUnit4.class) +public class TopWikipediaSessionsTest { + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testComputeTopUsers() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(Arrays.asList( + new TableRow().set("timestamp", 0).set("contributor_username", "user1"), + new TableRow().set("timestamp", 1).set("contributor_username", "user1"), + new TableRow().set("timestamp", 2).set("contributor_username", "user1"), + new TableRow().set("timestamp", 0).set("contributor_username", "user2"), + new TableRow().set("timestamp", 1).set("contributor_username", "user2"), + new TableRow().set("timestamp", 3601).set("contributor_username", "user2"), + new TableRow().set("timestamp", 3602).set("contributor_username", "user2"), + new TableRow().set("timestamp", 35 * 24 * 3600).set("contributor_username", "user3")))) + .apply(new TopWikipediaSessions.ComputeTopSessions(1.0)); + + DataflowAssert.that(output).containsInAnyOrder(Arrays.asList( + "user1 : [1970-01-01T00:00:00.000Z..1970-01-01T01:00:02.000Z)" + + " : 3 : 1970-01-01T00:00:00.000Z", + "user3 : [1970-02-05T00:00:00.000Z..1970-02-05T01:00:00.000Z)" + + " : 1 : 1970-02-01T00:00:00.000Z")); + + p.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/WordCountTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/WordCountTest.java new file mode 100644 index 000000000000..36efec738ddc --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/WordCountTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.examples.WordCount.CountWords; +import com.google.cloud.dataflow.examples.WordCount.ExtractWordsFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests of WordCount. + */ +@RunWith(JUnit4.class) +public class WordCountTest { + + /** Example test that tests a specific DoFn. */ + @Test + public void testExtractWordsFn() { + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractWordsFn()); + + Assert.assertThat(extractWordsFn.processBatch(" some input words "), + CoreMatchers.hasItems("some", "input", "words")); + Assert.assertThat(extractWordsFn.processBatch(" "), + CoreMatchers.hasItems()); + Assert.assertThat(extractWordsFn.processBatch(" some ", " input", " words"), + CoreMatchers.hasItems("some", "input", "words")); + } + + static final String[] WORDS_ARRAY = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final List WORDS = Arrays.asList(WORDS_ARRAY); + + static final String[] COUNTS_ARRAY = new String[] { + "hi: 5", "there: 1", "sue: 2", "bob: 2"}; + + /** Example test that tests a PTransform by using an in-memory input and inspecting the output. */ + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCountWords() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of()); + + PCollection output = input.apply(new CountWords()); + + DataflowAssert.that(output).containsInAnyOrder(COUNTS_ARRAY); + p.run(); + } +} diff --git a/pom.xml b/pom.xml new file mode 100644 index 000000000000..fd5b04376e43 --- /dev/null +++ b/pom.xml @@ -0,0 +1,202 @@ + + + + 4.0.0 + + + com.google + google + 5 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + Google Cloud Dataflow Java SDK - Parent + Google Cloud Dataflow Java SDK provides a simple, Java-based + interface for processing virtually any size data using Google cloud + resources. This artifact includes the parent POM for other Dataflow + artifacts. + http://cloud.google.com/dataflow + 2013 + + manual_build + + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + + Google Inc. + http://www.google.com + + + + + scm:git:git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + scm:git:git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + + + + 3.0.3 + + + + UTF-8 + + + pom + + sdk + examples + + + + + + + maven-compiler-plugin + 3.1 + + 1.7 + 1.7 + -Xlint:all + true + true + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.5 + + + + org.codehaus.mojo + versions-maven-plugin + 2.1 + + + + org.codehaus.mojo + exec-maven-plugin + 1.1 + + + verify + + java + + + + + + + java.util.logging.config.file + logging.properties + + + + + + + org.apache.felix + maven-bundle-plugin + 2.4.0 + + + + + org.jacoco + jacoco-maven-plugin + 0.7.1.201405082137 + + + + prepare-agent + + + file + true + + + + report + prepare-package + + report + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.15 + + ${testParallelValue} + 4 + + ${project.build.directory}/${project.artifactId}-${project.version}.jar + ${project.build.directory}/${project.artifactId}-${project.version}-tests.jar + + ${testGroups} + + ${runIntegrationTestOnService} + ${dataflowProjectName} + + false + + + + org.apache.maven.surefire + surefire-junit47 + 2.7.2 + + + + + + + + + + + org.codehaus.mojo + versions-maven-plugin + 2.1 + + + + dependency-updates-report + plugin-updates-report + + + + + + + diff --git a/sdk/pom.xml b/sdk/pom.xml new file mode 100644 index 000000000000..93a8f277a837 --- /dev/null +++ b/sdk/pom.xml @@ -0,0 +1,315 @@ + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + manual_build + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + Google Cloud Dataflow Java SDK - All + Google Cloud Dataflow Java SDK provides a simple, Java-based + interface for processing virtually any size data using Google cloud + resources. This artifact includes entire Dataflow Java SDK. + http://cloud.google.com/dataflow + + manual_build + + jar + + + ${maven.build.timestamp} + yyyy-MM-dd HH:mm + com.google.cloud.dataflow + false + none + + + + + + + DataflowPipelineTests + + true + com.google.cloud.dataflow.sdk.testing.RunnableOnService + both + + + + + + + + src/main/resources + true + + + + + + maven-compiler-plugin + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.12 + + ../checkstyle.xml + true + true + false + true + + + + + check + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + dataflow-sdk-compile + compile + + jar + + + + dataflow-sdk-test-compile + test-compile + + test-jar + + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.4 + + + attach-sources + compile + + jar + + + + attach-test-sources + test-compile + + test-jar + + + + + + + org.apache.felix + maven-bundle-plugin + true + + ${project.artifactId}-bundled-${project.version} + + + !${dataflow}.sdk.runners.worker.*, + !${dataflow}.sdk.streaming.*, + !${dataflow}.sdk.util.gcsio, + ${dataflow}.* + + true + *;scope=compile|runtime;inline=true + + + + + + + org.jacoco + jacoco-maven-plugin + + + + + org.apache.avro + avro-maven-plugin + 1.7.7 + + + schemas + generate-test-sources + + schema + + + ${project.basedir}/src/test/ + ${project.build.directory}/generated-test-sources/java + + + + + + + + + + com.google.apis + google-api-services-dataflow + v1beta3-rev1-1.19.0 + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-bigquery + v2-rev167-1.19.0 + + + + com.google.apis + google-api-services-compute + v1-rev34-1.19.0 + + + + com.google.apis + google-api-services-pubsub + v1beta1-rev9-1.19.0 + + + + com.google.apis + google-api-services-storage + v1-rev11-1.19.0 + + + + com.google.http-client + google-http-client-jackson2 + 1.19.0 + + + + com.google.oauth-client + google-oauth-client-java6 + 1.19.0 + + + + com.google.apis + google-api-services-datastore-protobuf + v1beta2-rev1-2.1.0 + + + + com.google.guava + guava + 18.0 + + + + com.fasterxml.jackson.core + jackson-core + 2.4.2 + + + + com.fasterxml.jackson.core + jackson-annotations + 2.4.2 + + + + com.fasterxml.jackson.core + jackson-databind + 2.4.2 + + + + + org.slf4j + slf4j-api + 1.7.7 + + + + org.slf4j + slf4j-jdk14 + 1.7.7 + + + + org.apache.avro + avro + 1.7.7 + + + + joda-time + joda-time + 2.4 + + + + + org.hamcrest + hamcrest-all + 1.3 + test + + + + junit + junit + 4.11 + test + + + + org.mockito + mockito-all + 1.9.5 + test + + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/Pipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/Pipeline.java new file mode 100644 index 000000000000..ec67fd7aabc3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/Pipeline.java @@ -0,0 +1,395 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.TransformHierarchy; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A Pipeline manages a DAG of PTransforms, and the PCollections + * that the PTransforms consume and produce. + * + *

After a {@code Pipeline} has been constructed, it can be executed, + * using a default or an explicit {@link PipelineRunner}. + * + *

Multiple {@code Pipeline}s can be constructed and executed independently + * and concurrently. + * + *

Each {@code Pipeline} is self-contained and isolated from any other + * {@code Pipeline}. The {@link PValues} that are inputs and outputs of each of a + * {@code Pipeline}'s {@link PTransform}s are also owned by that {@code Pipeline}. + * A {@code PValue} owned by one {@code Pipeline} can be read only by {@code PTransform}s + * also owned by that {@code Pipeline}. + * + *

Here's a typical example of use: + *

 {@code
+ * // Start by defining the options for the pipeline.
+ * PipelineOptions options = PipelineOptionsFactory.create();
+ * // Then create the pipeline.
+ * Pipeline p = Pipeline.create(options);
+ *
+ * // A root PTransform, like TextIO.Read or Create, gets added
+ * // to the Pipeline by being applied:
+ * PCollection lines =
+ *     p.apply(TextIO.Read.from("gs://bucket/dir/file*.txt"));
+ *
+ * // A Pipeline can have multiple root transforms:
+ * PCollection moreLines =
+ *     p.apply(TextIO.Read.from("gs://bucket/other/dir/file*.txt"));
+ * PCollection yetMoreLines =
+ *     p.apply(Create.of("yet", "more", "lines")).setCoder(StringUtf8Coder.of());
+ *
+ * // Further PTransforms can be applied, in an arbitrary (acyclic) graph.
+ * // Subsequent PTransforms (and intermediate PCollections etc.) are
+ * // implicitly part of the same Pipeline.
+ * PCollection allLines =
+ *     PCollectionList.of(lines).and(moreLines).and(yetMoreLines)
+ *     .apply(new Flatten());
+ * PCollection> wordCounts =
+ *     allLines
+ *     .apply(ParDo.of(new ExtractWords()))
+ *     .apply(new Count());
+ * PCollection formattedWordCounts =
+ *     wordCounts.apply(ParDo.of(new FormatCounts()));
+ * formattedWordCounts.apply(TextIO.Write.to("gs://bucket/dir/counts.txt"));
+ *
+ * // PTransforms aren't executed when they're applied, rather they're
+ * // just added to the Pipeline.  Once the whole Pipeline of PTransforms
+ * // is constructed, the Pipeline's PTransforms can be run using a
+ * // PipelineRunner.  The default PipelineRunner executes the Pipeline
+ * // directly, sequentially, in this one process, which is useful for
+ * // unit tests and simple experiments:
+ * p.run();
+ *
+ * } 
+ */ +public class Pipeline { + private static final Logger LOG = LoggerFactory.getLogger(Pipeline.class); + + ///////////////////////////////////////////////////////////////////////////// + // Public operations. + + /** + * Constructs a pipeline from the provided options. + * + * @return The newly created pipeline. + */ + public static Pipeline create(PipelineOptions options) { + Pipeline pipeline = new Pipeline(PipelineRunner.fromOptions(options), options); + LOG.debug("Creating {}", pipeline); + return pipeline; + } + + /** + * Returns a {@link PBegin} owned by this Pipeline. This is useful + * as the input of a root PTransform such as {@code TextIO.Read} or + * {@link com.google.cloud.dataflow.sdk.transforms.Create}. + */ + public PBegin begin() { + return PBegin.in(this); + } + + /** + * Starts using this pipeline with a root PTransform such as + * {@code TextIO.Read} or + * {@link com.google.cloud.dataflow.sdk.transforms.Create}. + * + *

+ * Alias for {@code begin().apply(root)}. + */ + public Output apply( + PTransform root) { + return begin().apply(root); + } + + /** + * Runs the Pipeline. + */ + public PipelineResult run() { + LOG.debug("Running {} via {}", this, runner); + try { + return runner.run(this); + } catch (UserCodeException e) { + // This serves to replace the stack with one that ends here and + // is caused by the caught UserCodeException, thereby splicing + // out all the stack frames in between the PipelineRunner itself + // and where the worker calls into the user's code. + throw new RuntimeException(e.getCause()); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + // Below here are operations that aren't normally called by users. + + /** + * Returns the {@link CoderRegistry} that this Pipeline uses. + */ + public CoderRegistry getCoderRegistry() { + if (coderRegistry == null) { + coderRegistry = new CoderRegistry(); + coderRegistry.registerStandardCoders(); + } + return coderRegistry; + } + + /** + * Sets the {@link CoderRegistry} that this Pipeline uses. + */ + public void setCoderRegistry(CoderRegistry coderRegistry) { + this.coderRegistry = coderRegistry; + } + + /** + * A PipelineVisitor can be passed into + * {@link Pipeline#traverseTopologically} to be called for each of the + * transforms and values in the Pipeline. + */ + public interface PipelineVisitor { + public void enterCompositeTransform(TransformTreeNode node); + public void leaveCompositeTransform(TransformTreeNode node); + public void visitTransform(TransformTreeNode node); + public void visitValue(PValue value, TransformTreeNode producer); + } + + /** + * Invokes the PipelineVisitor's + * {@link PipelineVisitor#visitTransform} and + * {@link PipelineVisitor#visitValue} operations on each of this + * Pipeline's PTransforms and PValues, in forward + * topological order. + * + *

Traversal of the pipeline causes PTransform and PValue instances to + * be marked as finished, at which point they may no longer be modified. + * + *

Typically invoked by {@link PipelineRunner} subclasses. + */ + public void traverseTopologically(PipelineVisitor visitor) { + Set visitedValues = new HashSet<>(); + // Visit all the transforms, which should implicitly visit all the values. + transforms.visit(visitor, visitedValues); + if (!visitedValues.containsAll(values)) { + throw new RuntimeException( + "internal error: should have visited all the values " + + "after visiting all the transforms"); + } + } + + /** + * Applies the given PTransform to the given Input, + * and returns its Output. + * + *

Called by PInput subclasses in their {@code apply} methods. + */ + public static + Output applyTransform(Input input, + PTransform transform) { + return input.getPipeline().applyInternal(input, transform); + } + + ///////////////////////////////////////////////////////////////////////////// + // Below here are internal operations, never called by users. + + private final PipelineRunner runner; + private final PipelineOptions options; + private final TransformHierarchy transforms = new TransformHierarchy(); + private Collection values = new ArrayList<>(); + private Set usedFullNames = new HashSet<>(); + private CoderRegistry coderRegistry; + + @Deprecated + protected Pipeline(PipelineRunner runner) { + this(runner, PipelineOptionsFactory.create()); + } + + protected Pipeline(PipelineRunner runner, PipelineOptions options) { + this.runner = runner; + this.options = options; + } + + @Override + public String toString() { return "Pipeline#" + hashCode(); } + + /** + * Applies a transformation to the given input. + * + * @see Pipeline#apply + */ + private + Output applyInternal(Input input, + PTransform transform) { + input.finishSpecifying(); + + TransformTreeNode parent = transforms.getCurrent(); + String namePrefix = parent.getFullName(); + String fullName = uniquifyInternal(namePrefix, transform.getName()); + TransformTreeNode child = new TransformTreeNode(parent, transform, fullName, input); + parent.addComposite(child); + + transforms.addInput(child, input); + + transform.setPipeline(this); + LOG.debug("Adding {} to {}", transform, this); + try { + transforms.pushNode(child); + Output output = runner.apply(transform, input); + transforms.setOutput(child, output); + + // recordAsOutput is a NOOP if already called; + output.recordAsOutput(this, child.getTransform()); + verifyOutputState(output, child); + return output; + } finally { + transforms.popNode(); + } + } + + /** + * Returns all producing transforms for the {@link PValue}s contained + * in {@code output}. + */ + private List getProducingTransforms(POutput output) { + List producingTransforms = new ArrayList<>(); + for (PValue value : output.expand()) { + PTransform transform = value.getProducingTransformInternal(); + if (transform != null) { + producingTransforms.add(transform); + } + } + return producingTransforms; + } + + /** + * Verifies that the output of a PTransform is correctly defined. + * + *

A non-composite transform must have all + * of its outputs registered as produced by the transform. + */ + private void verifyOutputState(POutput output, TransformTreeNode node) { + if (!node.isCompositeNode()) { + PTransform thisTransform = node.getTransform(); + List producingTransforms = getProducingTransforms(output); + for (PTransform producingTransform : producingTransforms) { + if (thisTransform != producingTransform) { + throw new IllegalArgumentException("Output of non-composite transform " + + thisTransform + " is registered as being produced by" + + " a different transform: " + producingTransform); + } + } + } + } + + /** + * Returns the configured pipeline runner. + */ + public PipelineRunner getRunner() { + return runner; + } + + /** + * Returns the configured pipeline options. + */ + public PipelineOptions getOptions() { + return options; + } + + /** + * Returns the output associated with a transform. + * + * @throws IllegalStateException if the transform has not been applied to the pipeline. + */ + public POutput getOutput(PTransform transform) { + TransformTreeNode node = transforms.getNode(transform); + Preconditions.checkState(node != null, + "Unknown transform: " + transform); + return node.getOutput(); + } + + /** + * Returns the input associated with a transform. + * + * @throws IllegalStateException if the transform has not been applied to the pipeline. + */ + public PInput getInput(PTransform transform) { + TransformTreeNode node = transforms.getNode(transform); + Preconditions.checkState(node != null, + "Unknown transform: " + transform); + return node.getInput(); + } + + /** + * Returns the fully qualified name of a transform. + * + * @throws IllegalStateException if the transform has not been applied to the pipeline. + */ + public String getFullName(PTransform transform) { + TransformTreeNode node = transforms.getNode(transform); + Preconditions.checkState(node != null, + "Unknown transform: " + transform); + return node.getFullName(); + } + + /** + * Returns a unique name for a transform with the given prefix (from + * enclosing transforms) and initial name. + * + *

For internal use only. + */ + private String uniquifyInternal(String namePrefix, String origName) { + String name = origName; + int suffixNum = 2; + while (true) { + String candidate = namePrefix.isEmpty() ? name : namePrefix + "/" + name; + if (usedFullNames.add(candidate)) { + return candidate; + } + // A duplicate! Retry. + name = origName + suffixNum++; + } + } + + /** + * Adds the given PValue to this Pipeline. + * + *

For internal use only. + */ + public void addValueInternal(PValue value) { + this.values.add(value); + value.setPipelineInternal(this); + LOG.debug("Adding {} to {}", value, this); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/PipelineResult.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/PipelineResult.java new file mode 100644 index 000000000000..7ab3845724f2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/PipelineResult.java @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +/** + * Result of {@link com.google.cloud.dataflow.sdk.Pipeline#run()}. + */ +public interface PipelineResult { + + // TODO: method to ask if pipeline is running / finished. + // TODO: method to retrieve error messages. + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AtomicCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AtomicCoder.java new file mode 100644 index 000000000000..6d032371207f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AtomicCoder.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import java.util.Collections; +import java.util.List; + +/** + * An AtomicCoder is one that has no component Coders or other state. + * All instances of its class are equal. + * + * @param the type of the values being transcoded + */ +public abstract class AtomicCoder extends StandardCoder { + protected AtomicCoder() {} + + @Override + public List> getCoderArguments() { return null; } + + /** + * Returns a list of values contained in the provided example + * value, one per type parameter. If there are no type parameters, + * returns the empty list. + */ + public static List getInstanceComponents(T exampleValue) { + return Collections.emptyList(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AvroCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AvroCoder.java new file mode 100644 index 000000000000..5ea631a970a7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AvroCoder.java @@ -0,0 +1,202 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.util.CloudObject; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.avro.reflect.ReflectDatumWriter; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * An encoder using Avro binary format. + *

+ * The Avro schema is generated using reflection on the element type, using + * Avro's + * org.apache.avro.reflect.ReflectData, + * and encoded as part of the {@code Coder} instance. + *

+ * For complete details about schema generation and how it can be controlled please see + * the + * org.apache.avro.reflect package. + * Only concrete classes with a no-argument constructor can be mapped to Avro records. + * All inherited fields that are not static or transient are used. Fields are not permitted to be + * null unless annotated by + * + * org.apache.avro.reflect.Nullable or a + * + * org.apache.avro.reflect.Union containing null. + *

+ * To use, specify the {@code Coder} type on a PCollection: + *

+ * {@code
+ * PCollection records =
+ *     input.apply(...)
+ *          .setCoder(AvroCoder.of(MyCustomElement.class);
+ * }
+ * 
+ *

+ * or annotate the element class using {@code @DefaultCoder}. + *


+ * {@literal @}DefaultCoder(AvroCoder.class)
+ * public class MyCustomElement {
+ *   ...
+ * }
+ * 
+ * + * @param the type of elements handled by this coder + */ +public class AvroCoder extends StandardCoder { + + /** + * Returns an {@code AvroCoder} instance for the provided element type. + * @param the element type + */ + public static AvroCoder of(Class type) { + return new AvroCoder<>(type, ReflectData.get().getSchema(type)); + } + + /** + * Returns an {@code AvroCoder} instance for the Avro schema. The implicit + * type is GenericRecord. + */ + public static AvroCoder of(Schema schema) { + return new AvroCoder<>(GenericRecord.class, schema); + } + + /** + * Returns an {@code AvroCoder} instance for the provided element type + * using the provided Avro schema. + * + *

If the type argument is GenericRecord, the schema may be arbitrary. + * Otherwise, the schema must correspond to the type provided. + * + * @param the element type + */ + public static AvroCoder of(Class type, Schema schema) { + return new AvroCoder<>(type, schema); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @JsonCreator + public static AvroCoder of( + @JsonProperty("type") String classType, + @JsonProperty("schema") String schema) throws ClassNotFoundException { + Schema.Parser parser = new Schema.Parser(); + return new AvroCoder(Class.forName(classType), parser.parse(schema)); + } + + private final Class type; + private final Schema schema; + private final DatumWriter writer; + private final DatumReader reader; + private final EncoderFactory encoderFactory = new EncoderFactory(); + private final DecoderFactory decoderFactory = new DecoderFactory(); + + protected AvroCoder(Class type, Schema schema) { + this.type = type; + this.schema = schema; + this.reader = createDatumReader(); + this.writer = createDatumWriter(); + } + + @Override + public void encode(T value, OutputStream outStream, Context context) + throws IOException { + BinaryEncoder encoder = encoderFactory.directBinaryEncoder(outStream, null); + writer.write(value, encoder); + encoder.flush(); + } + + @Override + public T decode(InputStream inStream, Context context) throws IOException { + BinaryDecoder decoder = decoderFactory.directBinaryDecoder(inStream, null); + return reader.read(null, decoder); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addString(result, "type", type.getName()); + addString(result, "schema", schema.toString()); + return result; + } + + /** + * Depends upon the structure being serialized. + */ + @Override + public boolean isDeterministic() { + return false; + } + + /** + * Returns a new DatumReader that can be used to read from + * an Avro file directly. + */ + public DatumReader createDatumReader() { + if (type.equals(GenericRecord.class)) { + return new GenericDatumReader<>(schema); + } else { + return new ReflectDatumReader<>(schema); + } + } + + /** + * Returns a new DatumWriter that can be used to write to + * an Avro file directly. + */ + public DatumWriter createDatumWriter() { + if (type.equals(GenericRecord.class)) { + return new GenericDatumWriter<>(schema); + } else { + return new ReflectDatumWriter<>(schema); + } + } + + /** + * Returns the schema used by this coder. + */ + public Schema getSchema() { + return schema; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoder.java new file mode 100644 index 000000000000..6af2d6f5ac4e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoder.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A BigEndianIntegerCoder encodes Integers in 4 bytes, big-endian. + */ +public class BigEndianIntegerCoder extends AtomicCoder { + @JsonCreator + public static BigEndianIntegerCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final BigEndianIntegerCoder INSTANCE = + new BigEndianIntegerCoder(); + + private BigEndianIntegerCoder() {} + + @Override + public void encode(Integer value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + new DataOutputStream(outStream).writeInt(value); + } + + @Override + public Integer decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new DataInputStream(inStream).readInt(); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Integer value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Integer value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + return 4; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoder.java new file mode 100644 index 000000000000..43ee9cab34be --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoder.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A BigEndianLongCoder encodes Longs in 8 bytes, big-endian. + */ +public class BigEndianLongCoder extends AtomicCoder { + @JsonCreator + public static BigEndianLongCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final BigEndianLongCoder INSTANCE = new BigEndianLongCoder(); + + private BigEndianLongCoder() {} + + @Override + public void encode(Long value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + new DataOutputStream(outStream).writeLong(value); + } + + @Override + public Long decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new DataInputStream(inStream).readLong(); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Long value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Long value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + return 8; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoder.java new file mode 100644 index 000000000000..c750d932dd06 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoder.java @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.io.ByteStreams; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A ByteArrayCoder encodes byte[] objects. + * + * If in a nested context, prefixes the encoded array with a VarInt encoding + * of the length. + */ +public class ByteArrayCoder extends AtomicCoder { + @JsonCreator + public static ByteArrayCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final ByteArrayCoder INSTANCE = new ByteArrayCoder(); + + private ByteArrayCoder() {} + + @Override + public void encode(byte[] value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null byte[]"); + } + if (!context.isWholeStream) { + VarInt.encode(value.length, outStream); + } + outStream.write(value); + } + + @Override + public byte[] decode(InputStream inStream, Context context) + throws IOException, CoderException { + if (context.isWholeStream) { + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + ByteStreams.copy(inStream, outStream); + return outStream.toByteArray(); + } else { + int length = VarInt.decodeInt(inStream); + if (length < 0) { + throw new IOException("invalid length " + length); + } + byte[] value = new byte[length]; + ByteStreams.readFully(inStream, value); + return value; + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(byte[] value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(byte[] value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null byte[]"); + } + long size = 0; + if (!context.isWholeStream) { + size += VarInt.getLength(value.length); + } + return size + value.length; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Coder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Coder.java new file mode 100644 index 000000000000..3760cb82003b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Coder.java @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.List; + +/** + * A {@code Coder} defines how to encode and decode values of type {@code T} into byte streams. + * + *

All methods of a {@code Coder} are required to be thread safe. + * + *

{@code Coder}s are serialized during job creation and deserialized + * before use, via JSON serialization. + * + *

See {@link SerializableCoder} for an example of a {@code Coder} that adds + * a custom field to the {@code Coder} serialization. It provides a + * constructor annotated with {@link + * com.fasterxml.jackson.annotation.JsonCreator}, which is a factory method + * used when deserializing a {@code Coder} instance. + * + *

See {@link KvCoder} for an example of a nested {@code Coder} type. + * + * @param the type of the values being transcoded + */ +public interface Coder extends Serializable { + /** The context in which encoding or decoding is being done. */ + public static class Context { + /** + * The outer context. The value being encoded or decoded takes + * up the remainder of the whole record/stream contents. + */ + public static final Context OUTER = new Context(true); + + /** + * The nested context. The value being encoded or decoded is + * (potentially) a part of a larger record/stream contents, and + * may have other parts encoded or decoded after it. + */ + public static final Context NESTED = new Context(false); + + /** + * Whether the encoded or decoded value fills the remainder of the + * output or input (resp.) record/stream contents. If so, then + * the size of the decoded value can be determined from the + * remaining size of the record/stream contents, and so explicit + * lengths aren't required. + */ + public final boolean isWholeStream; + + public Context(boolean isWholeStream) { + this.isWholeStream = isWholeStream; + } + + public Context nested() { + return NESTED; + } + } + + /** + * Encodes the given value of type {@code T} onto the given output stream + * in the given context. + * + * @throws IOException if writing to the {@code OutputStream} fails + * for some reason + * @throws CoderException if the value could not be encoded for some reason + */ + public void encode(T value, OutputStream outStream, Context context) + throws CoderException, IOException; + + /** + * Decodes a value of type {@code T} from the given input stream in + * the given context. Returns the decoded value. + * + * @throws IOException if reading from the {@code InputStream} fails + * for some reason + * @throws CoderException if the value could not be decoded for some reason + */ + public T decode(InputStream inStream, Context context) + throws CoderException, IOException; + + /** + * If this is a {@code Coder} for a parameterized type, returns the + * list of {@code Coder}s being used for each of the parameters, or + * returns {@code null} if this cannot be done or this is not a + * parameterized type. + */ + public List> getCoderArguments(); + + /** + * Returns the {@link CloudObject} that represents this {@code Coder}. + */ + public CloudObject asCloudObject(); + + /** + * Returns true if the coding is deterministic. + * + *

In order for a {@code Coder} to be considered deterministic, + * the following must be true: + *

    + *
  • two values which compare as equal (via {@code Object.equals()} + * or {@code Comparable.compareTo()}, if supported), have the same + * encoding. + *
  • the {@code Coder} always produces a canonical encoding, which is the + * same for an instance of an object even if produced on different + * computers at different times. + *
+ */ + public boolean isDeterministic(); + + /** + * Returns whether {@link #registerByteSizeObserver} cheap enough to + * call for every element, that is, if this {@code Coder} can + * calculate the byte size of the element to be coded in roughly + * constant time (or lazily). + * + *

Not intended to be called by user code, but instead by + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * implementations. + */ + public boolean isRegisterByteSizeObserverCheap(T value, Context context); + + /** + * Notifies the {@code ElementByteSizeObserver} about the byte size + * of the encoded value using this {@code Coder}. + * + *

Not intended to be called by user code, but instead by + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * implementations. + */ + public void registerByteSizeObserver( + T value, ElementByteSizeObserver observer, Context context) + throws Exception; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderException.java new file mode 100644 index 000000000000..1bbc3fa176b7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderException.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import java.io.IOException; + +/** + * A CoderException is thrown if there is a problem encoding or + * decoding a value. + */ +public class CoderException extends IOException { + public CoderException(String message) { + super(message); + } + + public CoderException(String message, Throwable cause) { + super(message, cause); + } + + public CoderException(Throwable cause) { + super(cause); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderRegistry.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderRegistry.java new file mode 100644 index 000000000000..670b4e3e320a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderRegistry.java @@ -0,0 +1,701 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.common.reflect.TypeToken; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.lang.reflect.WildcardType; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A CoderRegistry allows registering the default Coder to use for a Java class, + * and looking up and instantiating the default Coder for a Java type. + * + *

{@code CoderRegistry} uses the following mechanisms to determine a + * default {@link Coder} for a Java class, in order of precedence: + *

    + *
  • Registration: coders can be registered explicitly via + * {@link #registerCoder}. Built-in types are registered via + * {@link #registerStandardCoders()}. + *
  • Annotations: {@link DefaultCoder} can be used to annotate a type with + * the default {@code Coder} type. + *
  • Inheritance: {@link Serializable} objects are given a default + * {@code Coder} of {@link SerializableCoder}. + *
+ */ +public class CoderRegistry { + private static final Logger LOG = LoggerFactory.getLogger(CoderRegistry.class); + + /** A factory for default Coders for values of a particular class. */ + public abstract static class CoderFactory { + /** + * Returns the default Coder to use for values of a particular type, + * given the Coders for each of the type's generic parameter types. + * May return null if no default Coder can be created. + */ + public abstract Coder create( + List> typeArgumentCoders); + + /** + * Returns a list of objects contained in {@code value}, one per + * type argument, or {@code null} if none can be determined. + */ + public abstract List getInstanceComponents(Object value); + } + + /** A factory that always returns the coder with which it is instantiated. */ + public class ConstantCoderFactory extends CoderFactory { + private Coder coder; + + public ConstantCoderFactory(Coder coder) { + this.coder = coder; + } + + @Override + public Coder create(List> typeArgumentCoders) { + return this.coder; + } + + @Override + public List getInstanceComponents(Object value) { + return Collections.emptyList(); + } + } + + public CoderRegistry() {} + + /** + * Registers standard Coders with this CoderRegistry. + */ + public void registerStandardCoders() { + registerCoder(Double.class, DoubleCoder.class); + registerCoder(Instant.class, InstantCoder.class); + registerCoder(Integer.class, VarIntCoder.class); + registerCoder(Iterable.class, IterableCoder.class); + registerCoder(KV.class, KvCoder.class); + registerCoder(List.class, ListCoder.class); + registerCoder(Long.class, VarLongCoder.class); + registerCoder(String.class, StringUtf8Coder.class); + registerCoder(TableRow.class, TableRowJsonCoder.class); + registerCoder(Void.class, VoidCoder.class); + registerCoder(byte[].class, ByteArrayCoder.class); + registerCoder(URI.class, URICoder.class); + registerCoder(TimestampedValue.class, TimestampedValue.TimestampedValueCoder.class); + } + + /** + * Registers {@code coderClazz} as the default {@code Coder} + * class to handle encoding and decoding instances of {@code clazz} + * of type {@code T}. + * + *

{@code coderClazz} should have a static factory method with the + * following signature: + * + *

 {@code
+   * public static Coder of(Coder argCoder1, Coder argCoder2, ...)
+   * } 
+ * + *

This method will be called to create instances of {@code Coder} + * for values of type {@code T}, passing Coders for each of the generic type + * parameters of {@code T}. If {@code T} takes no generic type parameters, + * then the {@code of()} factory method should have no arguments. + * + *

If {@code T} is a parameterized type, then it should additionally + * have a method with the following signature: + * + *

 {@code
+   * public static List getInstanceComponents(T exampleValue);
+   * } 
+   *
+   * 

This method will be called to decompose a value during the coder + * inference process, to automatically choose coders for the components + */ + public void registerCoder(Class clazz, + Class coderClazz) { + int numTypeParameters = clazz.getTypeParameters().length; + + // Find the static factory method of coderClazz named 'of' with + // the appropriate number of type parameters. + + Class[] factoryMethodArgTypes = new Class[numTypeParameters]; + Arrays.fill(factoryMethodArgTypes, Coder.class); + + Method factoryMethod; + try { + factoryMethod = + coderClazz.getDeclaredMethod("of", factoryMethodArgTypes); + } catch (NoSuchMethodException | SecurityException exn) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "does not have an accessible method named 'of' with " + + numTypeParameters + " arguments of Coder type", + exn); + } + if (!Modifier.isStatic(factoryMethod.getModifiers())) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "method named 'of' with " + numTypeParameters + + " arguments of Coder type is not static"); + } + if (!coderClazz.isAssignableFrom(factoryMethod.getReturnType())) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "method named 'of' with " + numTypeParameters + + " arguments of Coder type does not return a " + coderClazz); + } + try { + if (!factoryMethod.isAccessible()) { + factoryMethod.setAccessible(true); + } + } catch (SecurityException exn) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "method named 'of' with " + numTypeParameters + + " arguments of Coder type is not accessible", + exn); + } + + // Find the static method to decompose values when inferring a coder, + // if there are type parameters for which we also need an example + // value + Method getComponentsMethod = null; + if (clazz.getTypeParameters().length > 0) { + try { + getComponentsMethod = coderClazz.getDeclaredMethod( + "getInstanceComponents", + clazz); + } catch (NoSuchMethodException | SecurityException exn) { + LOG.warn("cannot find getInstanceComponents for class {}. This may limit the ability to" + + " infer a Coder for values of this type.", coderClazz, exn); + } + } + + registerCoder(clazz, defaultCoderFactory(coderClazz, factoryMethod, getComponentsMethod)); + } + + public void registerCoder(Class rawClazz, + CoderFactory coderFactory) { + if (coderFactoryMap.put(rawClazz, coderFactory) != null) { + throw new IllegalArgumentException( + "cannot register multiple default Coder factories for " + rawClazz); + } + } + + public void registerCoder(Class rawClazz, Coder coder) { + CoderFactory factory = new ConstantCoderFactory(coder); + registerCoder(rawClazz, factory); + } + + /** + * Returns the Coder to use by default for values of the given type, + * or null if there is no default Coder. + */ + public Coder getDefaultCoder(TypeToken typeToken) { + return getDefaultCoder(typeToken, Collections.>emptyMap()); + } + + /** + * Returns the Coder to use by default for values of the given type, + * where the given context type uses the given context coder, + * or null if there is no default Coder. + */ + public Coder getDefaultCoder(TypeToken typeToken, + TypeToken contextTypeToken, + Coder contextCoder) { + return getDefaultCoder(typeToken, + createTypeBindings(contextTypeToken, contextCoder)); + } + + /** + * Returns the Coder to use on elements produced by this function, given + * the coder used for its input elements. + */ + public Coder getDefaultOutputCoder( + SerializableFunction fn, Coder inputCoder) { + return getDefaultCoder( + fn.getClass(), SerializableFunction.class, inputCoder); + } + + /** + * Returns the Coder to use for the last type parameter specialization + * of the subclass given Coders to use for all other type parameters + * specializations (if any). + */ + public Coder getDefaultCoder( + Class subClass, + Class baseClass, + Coder... knownCoders) { + Coder[] allCoders = new Coder[knownCoders.length + 1]; + // Last entry intentionally left null. + System.arraycopy(knownCoders, 0, allCoders, 0, knownCoders.length); + allCoders = getDefaultCoders(subClass, baseClass, allCoders); + @SuppressWarnings("unchecked") // trusted + Coder coder = (Coder) allCoders[knownCoders.length]; + return coder; + } + + /** + * Returns the Coder to use for the specified type parameter specialization + * of the subclass, given Coders to use for all other type parameters + * (if any). + */ + @SuppressWarnings("unchecked") + public Coder getDefaultCoder( + Class subClass, + Class baseClass, + Map> knownCoders, + String paramName) { + // TODO: Don't infer unneeded params. + return (Coder) getDefaultCoders(subClass, baseClass, knownCoders) + .get(paramName); + } + + /** + * Returns the Coder to use for the provided example value, if it can + * be determined, otherwise returns {@code null}. If more than one + * default coder matches, this will raise an exception. + */ + public Coder getDefaultCoder(T exampleValue) { + Class clazz = exampleValue.getClass(); + + if (clazz.getTypeParameters().length == 0) { + // Trust that getDefaultCoder returns a valid + // Coder for non-generic clazz. + @SuppressWarnings("unchecked") + Coder coder = (Coder) getDefaultCoder(clazz); + return coder; + } else { + CoderFactory factory = getDefaultCoderFactory(clazz); + if (factory == null) { + return null; + } + + List components = factory.getInstanceComponents(exampleValue); + if (components == null) { + return null; + } + + // componentcoders = components.map(this.getDefaultCoder) + List> componentCoders = new ArrayList<>(); + for (Object component : components) { + Coder componentCoder = getDefaultCoder(component); + if (componentCoder == null) { + return null; + } else { + componentCoders.add(componentCoder); + } + } + + // Trust that factory.create maps from valid component coders + // to a valid Coder. + @SuppressWarnings("unchecked") + Coder coder = (Coder) factory.create(componentCoders); + return coder; + } + } + + + /** + * Returns a Map from each of baseClass's type parameters to the Coder to + * use by default for it, in the context of subClass's specialization of + * baseClass. + * + *

For example, if baseClass is Map.class and subClass extends + * {@code Map} then this will return the registered Coders + * to use for String and Integer as a {"K": stringCoder, "V": intCoder} Map. + * The knownCoders parameter can be used to provide known coders for any of + * the parameters which will be used to infer the others. + * + * @param subClass the concrete type whose specializations are being inferred + * @param baseClass the base type, a parameterized class + * @param knownCoders a map corresponding to the set of known coders indexed + * by parameter name + */ + public Map> getDefaultCoders( + Class subClass, + Class baseClass, + Map> knownCoders) { + TypeVariable>[] typeParams = baseClass.getTypeParameters(); + Coder[] knownCodersArray = new Coder[typeParams.length]; + for (int i = 0; i < typeParams.length; i++) { + knownCodersArray[i] = knownCoders.get(typeParams[i].getName()); + } + Coder[] resultArray = getDefaultCoders( + subClass, baseClass, knownCodersArray); + Map> result = new HashMap<>(); + for (int i = 0; i < typeParams.length; i++) { + result.put(typeParams[i].getName(), resultArray[i]); + } + return result; + } + + /** + * Returns an array listing, for each of baseClass's type parameters, the + * Coder to use by default for it, in the context of subClass's specialization + * of baseClass. + * + *

For example, if baseClass is Map.class and subClass extends + * {@code Map} then this will return the registered Coders + * to use for String and Integer in that order. The knownCoders parameter + * can be used to provide known coders for any of the parameters which will + * be used to infer the others. + * + *

If a type cannot be inferred, null is returned. + * + * @param subClass the concrete type whose specializations are being inferred + * @param baseClass the base type, a parameterized class + * @param knownCoders an array corresponding to the set of base class + * type parameters. Each entry is can be either a Coder (in which + * case it will be used for inference) or null (in which case it + * will be inferred). May be null to indicate the entire set of + * parameters should be inferred. + * @throws IllegalArgumentException if baseClass doesn't have type parameters + * or if the length of knownCoders is not equal to the number of type + * parameters + */ + public Coder[] getDefaultCoders( + Class subClass, + Class baseClass, + Coder[] knownCoders) { + Type type = TypeToken.of(subClass).getSupertype(baseClass).getType(); + if (!(type instanceof ParameterizedType)) { + throw new IllegalArgumentException(type + " is not a ParameterizedType"); + } + ParameterizedType parameterizedType = (ParameterizedType) type; + Type[] typeArgs = parameterizedType.getActualTypeArguments(); + if (knownCoders == null) { + knownCoders = new Coder[typeArgs.length]; + } else if (typeArgs.length != knownCoders.length) { + throw new IllegalArgumentException( + "Class " + baseClass + " has " + typeArgs.length + " parameters, " + + "but " + knownCoders.length + " coders are requested."); + } + Map> context = new HashMap<>(); + for (int i = 0; i < knownCoders.length; i++) { + if (knownCoders[i] != null) { + if (!isCompatible(knownCoders[i], typeArgs[i])) { + throw new IllegalArgumentException( + "Cannot encode elements of type " + typeArgs[i] + + " with " + knownCoders[i]); + } + context.put(typeArgs[i], knownCoders[i]); + } + } + Coder[] result = new Coder[typeArgs.length]; + for (int i = 0; i < knownCoders.length; i++) { + if (knownCoders[i] != null) { + result[i] = knownCoders[i]; + } else { + result[i] = getDefaultCoder(typeArgs[i], context); + } + } + return result; + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns whether the given coder can possibly encode elements + * of the given type. + */ + static boolean isCompatible(Coder coder, Type type) { + Type coderType = + ((ParameterizedType) + TypeToken.of(coder.getClass()).getSupertype(Coder.class).getType()) + .getActualTypeArguments()[0]; + if (type instanceof TypeVariable) { + return true; // Can't rule it out. + } + Class coderClass = TypeToken.of(coderType).getRawType(); + if (!coderClass.isAssignableFrom(TypeToken.of(type).getRawType())) { + return false; + } + if (coderType instanceof ParameterizedType + && !isNullOrEmpty(coder.getCoderArguments())) { + @SuppressWarnings("unchecked") + Type[] typeArguments = + ((ParameterizedType) + TypeToken.of(type).getSupertype((Class) coderClass).getType()) + .getActualTypeArguments(); + List> typeArgumentCoders = coder.getCoderArguments(); + assert typeArguments.length == typeArgumentCoders.size(); + for (int i = 0; i < typeArguments.length; i++) { + if (!isCompatible( + typeArgumentCoders.get(i), + TypeToken.of(type).resolveType(typeArguments[i]).getType())) { + return false; + } + } + } + return true; // For all we can tell. + } + + private static boolean isNullOrEmpty(Collection c) { + return c == null || c.size() == 0; + } + + /** + * The map of classes to the CoderFactories to use to create their + * default Coders. + */ + Map, CoderFactory> coderFactoryMap = new HashMap<>(); + + /** + * Returns a CoderFactory that invokes the given static factory method + * to create the Coder. + */ + static CoderFactory defaultCoderFactory( + final Class coderClazz, + final Method coderFactoryMethod, + final Method getComponentsMethod) { + + return new CoderFactory() { + @Override + public Coder create(List> typeArgumentCoders) { + try { + return (Coder) coderFactoryMethod.invoke( + null /* static */, typeArgumentCoders.toArray()); + } catch (IllegalAccessException | + IllegalArgumentException | + InvocationTargetException | + NullPointerException | + ExceptionInInitializerError exn) { + throw new IllegalStateException( + "error when invoking Coder factory method " + coderFactoryMethod, + exn); + } + } + + @Override + public List getInstanceComponents(Object value) { + if (getComponentsMethod == null) { + throw new IllegalStateException( + "no suitable static getInstanceComponents method available for " + + "Coder " + coderClazz); + } + + try { + @SuppressWarnings("unchecked") + List result = (List) (getComponentsMethod.invoke( + null /* static */, value)); + return result; + } catch (IllegalAccessException + | IllegalArgumentException + | InvocationTargetException + | NullPointerException + | ExceptionInInitializerError exn) { + throw new IllegalStateException( + "error when invoking Coder getComponents method " + getComponentsMethod, + exn); + } + } + }; + } + + static CoderFactory defaultCoderFactory(Class coderClazz, final Method coderFactoryMethod) { + return defaultCoderFactory(coderClazz, coderFactoryMethod, null); + } + + /** + * Returns the CoderFactory to use to create default Coders for + * instances of the given class, or null if there is no default + * CoderFactory registered. + */ + CoderFactory getDefaultCoderFactory(Class clazz) { + CoderFactory coderFactory = coderFactoryMap.get(clazz); + if (coderFactory == null) { + LOG.debug("No Coder registered for {}", clazz); + } + return coderFactory; + } + + /** + * Returns the Coder to use by default for values of the given type, + * in a context where the given types use the given coders, + * or null if there is no default Coder. + */ + Coder getDefaultCoder(TypeToken typeToken, + Map> typeCoderBindings) { + Coder defaultCoder = getDefaultCoder(typeToken.getType(), + typeCoderBindings); + LOG.debug("Default Coder for {}: {}", typeToken, defaultCoder); + @SuppressWarnings("unchecked") + Coder result = (Coder) defaultCoder; + return result; + } + + /** + * Returns the Coder to use by default for values of the given type, + * in a context where the given types use the given coders, + * or null if there is no default Coder. + */ + Coder getDefaultCoder(Type type, Map> typeCoderBindings) { + Coder coder = typeCoderBindings.get(type); + if (coder != null) { + return coder; + } + if (type instanceof Class) { + return getDefaultCoder((Class) type); + } else if (type instanceof ParameterizedType) { + return this.getDefaultCoder((ParameterizedType) type, + typeCoderBindings); + } else if (type instanceof TypeVariable + || type instanceof WildcardType) { + // No default coder for an unknown generic type. + LOG.debug("No Coder for unknown generic type {}", type); + return null; + } else { + throw new RuntimeException( + "internal error: unexpected kind of Type: " + type); + } + } + + /** + * Returns the Coder to use by default for values of the given + * class, or null if there is no default Coder. + */ + Coder getDefaultCoder(Class clazz) { + CoderFactory coderFactory = getDefaultCoderFactory(clazz); + if (coderFactory != null) { + LOG.debug("Default Coder for {} found by factory", clazz); + return coderFactory.create(Collections.>emptyList()); + } + + DefaultCoder defaultAnnotation = clazz.getAnnotation( + DefaultCoder.class); + if (defaultAnnotation != null) { + LOG.debug("Default Coder for {} found by DefaultCoder annotation", clazz); + return InstanceBuilder.ofType(Coder.class) + .fromClass(defaultAnnotation.value()) + .fromFactoryMethod("of") + .withArg(Class.class, clazz) + .build(); + } + + // Interface-based defaults. + if (Serializable.class.isAssignableFrom(clazz)) { + @SuppressWarnings("unchecked") + Class serializableClazz = + (Class) clazz; + LOG.debug("Default Coder for {}: SerializableCoder", serializableClazz); + return SerializableCoder.of(serializableClazz); + } + + LOG.debug("No default Coder for {}", clazz); + return null; + } + + /** + * Returns the Coder to use by default for values of the given + * parameterized type, in a context where the given types use the + * given coders, or null if there is no default Coder. + */ + Coder getDefaultCoder( + ParameterizedType type, + Map> typeCoderBindings) { + Class rawClazz = (Class) type.getRawType(); + CoderFactory coderFactory = getDefaultCoderFactory(rawClazz); + if (coderFactory == null) { + return null; + } + List> typeArgumentCoders = new ArrayList<>(); + for (Type typeArgument : type.getActualTypeArguments()) { + Coder typeArgumentCoder = getDefaultCoder(typeArgument, + typeCoderBindings); + if (typeArgumentCoder == null) { + return null; + } + typeArgumentCoders.add(typeArgumentCoder); + } + return coderFactory.create(typeArgumentCoders); + } + + /** + * Returns a Map where each of the type variables embedded in the + * given type are mapped to the corresponding Coders in the given + * coder. + */ + Map> createTypeBindings(TypeToken typeToken, + Coder coder) { + Map> typeCoderBindings = new HashMap<>(); + fillTypeBindings(typeToken.getType(), coder, typeCoderBindings); + return typeCoderBindings; + } + + /** + * Adds to the given map bindings from each of the type variables + * embedded in the given type to the corresponding Coders in the + * given coder. + */ + void fillTypeBindings(Type type, + Coder coder, + Map> typeCoderBindings) { + if (type instanceof TypeVariable) { + LOG.debug("Binding type {} to Coder {}", type, coder); + typeCoderBindings.put(type, coder); + } else if (type instanceof ParameterizedType) { + fillTypeBindings((ParameterizedType) type, + coder, + typeCoderBindings); + } + } + + /** + * Adds to the given map bindings from each of the type variables + * embedded in the given parameterized type to the corresponding + * Coders in the given coder. + */ + void fillTypeBindings(ParameterizedType type, + Coder coder, + Map> typeCoderBindings) { + Type[] typeArguments = type.getActualTypeArguments(); + List> coderArguments = coder.getCoderArguments(); + if (coderArguments == null + || typeArguments.length != coderArguments.size()) { + return; + } + for (int i = 0; i < typeArguments.length; i++) { + fillTypeBindings(typeArguments[i], + coderArguments.get(i), + typeCoderBindings); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CollectionCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CollectionCoder.java new file mode 100644 index 000000000000..546695dfefe8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CollectionCoder.java @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collection; +import java.util.List; + +/** + * A CollectionCoder encodes Collections. + * + * @param the type of the elements of the Collections being transcoded + */ +public class CollectionCoder extends IterableLikeCoder> { + + public static CollectionCoder of(Coder elemCoder) { + return new CollectionCoder<>(elemCoder); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + @JsonCreator + public static CollectionCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + /** + * Returns the first element in this collection if it is non-empty, + * otherwise returns {@code null}. + */ + public static List getInstanceComponents( + Collection exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + CollectionCoder(Coder elemCoder) { + super(elemCoder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CustomCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CustomCoder.java new file mode 100644 index 000000000000..6b31297a1071 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CustomCoder.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.StringUtils; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; + +/** + * An abstract base class for writing Coders that encodes itself via java + * serialization. Subclasses only need to implement the {@link Coder#encode} + * and {@link Coder#decode} methods. + * + *

+ * Not to be confused with {@link SerializableCoder} that encodes serializables. + * + * @param the type of elements handled by this coder + */ +public abstract class CustomCoder extends AtomicCoder + implements Serializable { + + @JsonCreator + public static CustomCoder of( + // N.B. typeId is a required parameter here, since a field named "@type" + // is presented to the deserializer as an input. + // + // If this method did not consume the field, Jackson2 would observe an + // unconsumed field and a returned value of a derived type. So Jackson2 + // would attempt to update the returned value with the unconsumed field + // data, The standard JsonDeserializer does not implement a mechanism for + // updating constructed values, so it would throw an exception, causing + // deserialization to fail. + @JsonProperty(value = "@type", required = false) String typeId, + @JsonProperty("type") String type, + @JsonProperty("serialized_coder") String serializedCoder) { + return (CustomCoder) SerializableUtils.deserializeFromByteArray( + StringUtils.jsonStringToByteArray(serializedCoder), + type); + } + + @Override + public CloudObject asCloudObject() { + // N.B. We use the CustomCoder class, not the derived class, since during + // deserialization we will be using the CustomCoder's static factory method + // to construct an instance of the derived class. + CloudObject result = CloudObject.forClass(CustomCoder.class); + addString(result, "type", getClass().getName()); + addString(result, "serialized_coder", + StringUtils.byteArrayToJsonString( + SerializableUtils.serializeToByteArray(this))); + return result; + } + + @Override + public boolean isDeterministic() { + return false; + } + + // This coder inherits isRegisterByteSizeObserverCheap, + // getEncodedElementByteSize and registerByteSizeObserver + // from StandardCoder. Override if we can do better. +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DefaultCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DefaultCoder.java new file mode 100644 index 000000000000..6c6f4197c5a8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DefaultCoder.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Specifies a default {@link Coder} class to handle encoding and decoding + * instances of the annotated class. + * + *

The specified {@code Coder} must implement a function with the following + * signature: + *

{@code
+ * public static Coder of(Class clazz) {...}
+ * }
+ * + *

For example, to configure the use of Java serialization as the default + * for a class, annotate the class to use + * {@link com.google.cloud.dataflow.sdk.coders.SerializableCoder} as follows:the + * + *


+ * {@literal @}DefaultCoder(SerializableCoder.class)
+ * public class MyCustomDataType {
+ *   // ...
+ * }
+ * 
+ * + *

Similarly, to configure the use of + * {@link com.google.cloud.dataflow.sdk.coders.AvroCoder} as the default: + *


+ * {@literal @}DefaultCoder(AvroCoder.class)
+ * public class MyCustomDataType {
+ *   public MyCustomDataType() {}   // Avro requires an empty constructor.
+ *   // ...
+ * }
+ * 
+ * + *

Coders specified explicitly via + * {@link com.google.cloud.dataflow.sdk.values.PCollection#setCoder(Coder) + * PCollection.setCoder} + * take precedence, followed by Coders registered at runtime via + * {@link CoderRegistry#registerCoder}. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +@SuppressWarnings("rawtypes") +public @interface DefaultCoder { + Class value(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DoubleCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DoubleCoder.java new file mode 100644 index 000000000000..6b531ad0dc45 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DoubleCoder.java @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A DoubleCoder encodes Doubles in 8 bytes. + */ +public class DoubleCoder extends AtomicCoder { + @JsonCreator + public static DoubleCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final DoubleCoder INSTANCE = new DoubleCoder(); + + private DoubleCoder() {} + + @Override + public void encode(Double value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Double"); + } + new DataOutputStream(outStream).writeDouble(value); + } + + @Override + public Double decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new DataInputStream(inStream).readDouble(); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * Floating-point operations are not guaranteed to be deterministic, even + * if the storage format might be, so floating point representations are not + * recommended for use in operations which require deterministic inputs. + */ + @Override + public boolean isDeterministic() { + return false; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Double value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Double value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Double"); + } + return 8; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/EntityCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/EntityCoder.java new file mode 100644 index 000000000000..988a04c03160 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/EntityCoder.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.datastore.DatastoreV1.Entity; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * An EntityCoder encodes/decodes Datastore Entity objects. + */ +public class EntityCoder extends AtomicCoder { + + @JsonCreator + public static EntityCoder of() { + return INSTANCE; + } + + /***************************/ + + private static final EntityCoder INSTANCE = new EntityCoder(); + + private EntityCoder() {} + + @Override + public void encode(Entity value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Entity"); + } + + // Since Entity implements com.google.protobuf.MessageLite, + // we could directly use writeTo to write to a OutputStream object + outStream.write(java.nio.ByteBuffer.allocate(4).putInt(value.getSerializedSize()).array()); + value.writeTo(outStream); + outStream.flush(); + } + + @Override + public Entity decode(InputStream inStream, Context context) + throws IOException { + byte[] entitySize = new byte[4]; + inStream.read(entitySize, 0, 4); + int size = java.nio.ByteBuffer.wrap(entitySize).getInt(); + byte[] data = new byte[size]; + inStream.read(data, 0, size); + return Entity.parseFrom(data); + } + + @Override + protected long getEncodedElementByteSize(Entity value, Context context) + throws Exception { + return value.getSerializedSize(); + } + + /** + * A datastore kind can hold arbitrary Object instances, + * which makes the encoding non-deterministic. + */ + @Override + public boolean isDeterministic() { + return false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/InstantCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/InstantCoder.java new file mode 100644 index 000000000000..319012439170 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/InstantCoder.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.joda.time.Instant; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A InstantCoder encodes joda Instant. + */ +public class InstantCoder extends AtomicCoder { + @JsonCreator + public static InstantCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final InstantCoder INSTANCE = new InstantCoder(); + + private InstantCoder() {} + + @Override + public void encode(Instant value, OutputStream outStream, Context context) + throws CoderException, IOException { + // Shift the millis by Long.MIN_VALUE so that negative values sort before positive + // values when encoded. The overflow is well-defined: + // http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.18.2 + BigEndianLongCoder.of().encode(value.getMillis() - Long.MIN_VALUE, outStream, context); + } + + @Override + public Instant decode(InputStream inStream, Context context) + throws CoderException, IOException { + return new Instant(BigEndianLongCoder.of().decode(inStream, context) + Long.MIN_VALUE); + } + + @Override + public boolean isDeterministic() { + return true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableCoder.java new file mode 100644 index 000000000000..801dd2042cfd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableCoder.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * An IterableCoder encodes Iterables. + * + * @param the type of the elements of the Iterables being transcoded + */ +public class IterableCoder extends IterableLikeCoder> { + + public static IterableCoder of(Coder elemCoder) { + return new IterableCoder<>(elemCoder); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + @JsonCreator + public static IterableCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of(components.get(0)); + } + + /** + * Returns the first element in this iterable if it is non-empty, + * otherwise returns {@code null}. + */ + public static List getInstanceComponents( + Iterable exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + IterableCoder(Coder elemCoder) { + super(elemCoder); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_STREAM_LIKE, true); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableLikeCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableLikeCoder.java new file mode 100644 index 000000000000..e6ecdbe26bb9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableLikeCoder.java @@ -0,0 +1,227 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservableIterable; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Observable; +import java.util.Observer; + +/** + * The base class of Coders for Iterable subclasses. + * + * @param the type of the elements of the Iterables being transcoded + * @param the type of the Iterables being transcoded + */ +public abstract class IterableLikeCoder> + extends StandardCoder { + + public Coder getElemCoder() { return elemCoder; } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + final Coder elemCoder; + + /** + * Returns the first element in this iterable-like if it is non-empty, + * otherwise returns {@code null}. + */ + protected static > + List getInstanceComponentsHelper( + IT exampleValue) { + for (T value : exampleValue) { + return Arrays.asList(value); + } + return null; + } + + protected IterableLikeCoder(Coder elemCoder) { + this.elemCoder = elemCoder; + } + + @Override + public void encode(IT iterable, OutputStream outStream, Context context) + throws IOException, CoderException { + if (iterable == null) { + throw new CoderException("cannot encode a null Iterable"); + } + Context nestedContext = context.nested(); + DataOutputStream dataOutStream = new DataOutputStream(outStream); + if (iterable instanceof Collection) { + // We can know the size of the Iterable. Use an encoding with a + // leading size field, followed by that many elements. + Collection collection = (Collection) iterable; + dataOutStream.writeInt(collection.size()); + for (T elem : collection) { + elemCoder.encode(elem, dataOutStream, nestedContext); + } + } else { + // We don't know the size without traversing it. So use a + // "hasNext" sentinel before each element. + // TODO: Don't use the sentinel if context.isWholeStream. + dataOutStream.writeInt(-1); + for (T elem : iterable) { + dataOutStream.writeBoolean(true); + elemCoder.encode(elem, dataOutStream, nestedContext); + } + dataOutStream.writeBoolean(false); + } + // Make sure all our output gets pushed to the underlying outStream. + dataOutStream.flush(); + } + + @Override + public IT decode(InputStream inStream, Context context) + throws IOException, CoderException { + Context nestedContext = context.nested(); + DataInputStream dataInStream = new DataInputStream(inStream); + int size = dataInStream.readInt(); + if (size >= 0) { + List elements = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + elements.add(elemCoder.decode(dataInStream, nestedContext)); + } + return (IT) elements; + } else { + // We don't know the size a priori. Check if we're done with + // each element. + List elements = new ArrayList<>(); + while (dataInStream.readBoolean()) { + elements.add(elemCoder.decode(dataInStream, nestedContext)); + } + return (IT) elements; + } + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(elemCoder); + } + + /** + * Encoding is not deterministic for the general Iterable case, as it depends + * upon the type of iterable. This may allow two objects to compare as equal + * while the encoding differs. + */ + @Override + public boolean isDeterministic() { + return false; + } + + /** + * Returns whether iterable can use lazy counting, since that + * requires minimal extra computation. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(IT iterable, Context context) { + return iterable instanceof ElementByteSizeObservableIterable; + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + IT iterable, ElementByteSizeObserver observer, Context context) + throws Exception { + if (iterable == null) { + throw new CoderException("cannot encode a null Iterable"); + } + Context nestedContext = context.nested(); + + if (iterable instanceof ElementByteSizeObservableIterable) { + observer.setLazy(); + ElementByteSizeObservableIterable observableIT = + (ElementByteSizeObservableIterable) iterable; + observableIT.addObserver( + new IteratorObserver(observer, iterable instanceof Collection)); + } else { + if (iterable instanceof Collection) { + // We can know the size of the Iterable. Use an encoding with a + // leading size field, followed by that many elements. + Collection collection = (Collection) iterable; + observer.update(4L); + for (T elem : collection) { + elemCoder.registerByteSizeObserver(elem, observer, nestedContext); + } + } else { + // We don't know the size without traversing it. So use a + // "hasNext" sentinel before each element. + // TODO: Don't use the sentinel if context.isWholeStream. + observer.update(4L); + for (T elem : iterable) { + observer.update(1L); + elemCoder.registerByteSizeObserver(elem, observer, nestedContext); + } + observer.update(1L); + } + } + } + + /** + * An observer that gets notified when an observable iterator + * returns a new value. This observer just notifies an outerObserver + * about this event. Additionally, the outerObserver is notified + * about additional separators that are transparently added by this + * coder. + */ + private class IteratorObserver implements Observer { + private final ElementByteSizeObserver outerObserver; + private final boolean countable; + + public IteratorObserver(ElementByteSizeObserver outerObserver, + boolean countable) { + this.outerObserver = outerObserver; + this.countable = countable; + + if (countable) { + // Additional 4 bytes are due to size. + outerObserver.update(4L); + } else { + // Additional 5 bytes are due to size = -1 (4 bytes) and + // hasNext = false (1 byte). + outerObserver.update(5L); + } + } + + @Override + public void update(Observable obs, Object obj) { + if (!(obj instanceof Long)) { + throw new AssertionError("unexpected parameter object"); + } + + if (countable) { + outerObserver.update(obs, obj); + } else { + // Additional 1 byte is due to hasNext = true flag. + outerObserver.update(obs, 1 + (long) obj); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoder.java new file mode 100644 index 000000000000..000d6ca75807 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoder.java @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; + +/** + * A KvCoder encodes KVs. + * + * @param the type of the keys of the KVs being transcoded + * @param the type of the values of the KVs being transcoded + */ +public class KvCoder extends KvCoderBase> { + + public static KvCoder of(Coder keyCoder, + Coder valueCoder) { + return new KvCoder<>(keyCoder, valueCoder); + } + + @JsonCreator + public static KvCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + return of(components.get(0), components.get(1)); + } + + public static List getInstanceComponents( + KV exampleValue) { + return Arrays.asList( + exampleValue.getKey(), + exampleValue.getValue()); + } + + public Coder getKeyCoder() { return keyCoder; } + public Coder getValueCoder() { return valueCoder; } + + ///////////////////////////////////////////////////////////////////////////// + + Coder keyCoder; + Coder valueCoder; + + KvCoder(Coder keyCoder, Coder valueCoder) { + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + } + + @Override + public void encode(KV kv, OutputStream outStream, Context context) + throws IOException, CoderException { + if (kv == null) { + throw new CoderException("cannot encode a null KV"); + } + Context nestedContext = context.nested(); + keyCoder.encode(kv.getKey(), outStream, nestedContext); + valueCoder.encode(kv.getValue(), outStream, nestedContext); + } + + @Override + public KV decode(InputStream inStream, Context context) + throws IOException, CoderException { + Context nestedContext = context.nested(); + K key = keyCoder.decode(inStream, nestedContext); + V value = valueCoder.decode(inStream, nestedContext); + return KV.of(key, value); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(keyCoder, valueCoder); + } + + @Override + public boolean isDeterministic() { + return getKeyCoder().isDeterministic() && getValueCoder().isDeterministic(); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_PAIR_LIKE, true); + return result; + } + + /** + * Returns whether both keyCoder and valueCoder are considered not expensive. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(KV kv, Context context) { + return keyCoder.isRegisterByteSizeObserverCheap(kv.getKey(), + context.nested()) + && valueCoder.isRegisterByteSizeObserverCheap(kv.getValue(), + context.nested()); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + KV kv, ElementByteSizeObserver observer, Context context) + throws Exception { + if (kv == null) { + throw new CoderException("cannot encode a null KV"); + } + keyCoder.registerByteSizeObserver( + kv.getKey(), observer, context.nested()); + valueCoder.registerByteSizeObserver( + kv.getValue(), observer, context.nested()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoderBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoderBase.java new file mode 100644 index 000000000000..b959e1c3c576 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoderBase.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * A abstract base class for KvCoder. Works around a Jackson2 bug tickled when building + * {@link KvCoder} directly (as of this writing, Jackson2 walks off the end of + * an array when it tries to deserialize a class with multiple generic type + * parameters). This class should be removed when possible. + * @param the type of values being transcoded + */ +public abstract class KvCoderBase extends StandardCoder { + @JsonCreator + public static KvCoderBase of( + // N.B. typeId is a required parameter here, since a field named "@type" + // is presented to the deserializer as an input. + // + // If this method did not consume the field, Jackson2 would observe an + // unconsumed field and a returned value of a derived type. So Jackson2 + // would attempt to update the returned value with the unconsumed field + // data. The standard JsonDeserializer does not implement a mechanism for + // updating constructed values, so it would throw an exception, causing + // deserialization to fail. + @JsonProperty(value = "@type", required = false) String typeId, + @JsonProperty(value = PropertyNames.IS_PAIR_LIKE, required = false) boolean isPairLike, + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + return KvCoder.of(components); + } + + protected KvCoderBase() {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ListCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ListCoder.java new file mode 100644 index 000000000000..ab9d8147aa1f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ListCoder.java @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * A ListCoder encodes Lists. + * + * @param the type of the elements of the Lists being transcoded + */ +public class ListCoder extends IterableLikeCoder> { + + public static ListCoder of(Coder elemCoder) { + return new ListCoder<>(elemCoder); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + @JsonCreator + public static ListCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + /** + * Returns the first element in this list if it is non-empty, + * otherwise returns {@code null}. + */ + public static List getInstanceComponents(List exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + ListCoder(Coder elemCoder) { + super(elemCoder); + } + + /** + * List sizes are always known, so ListIterable may be deterministic while + * the general IterableLikeCoder is not. + */ + @Override + public boolean isDeterministic() { + return getElemCoder().isDeterministic(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoder.java new file mode 100644 index 000000000000..fa3fc5895015 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoder.java @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * A MapCoder encodes Maps. + * + * @param the type of the keys of the KVs being transcoded + * @param the type of the values of the KVs being transcoded + */ +public class MapCoder extends MapCoderBase> { + + /** + * Produces a MapCoder with the given keyCoder and valueCoder. + */ + public static MapCoder of( + Coder keyCoder, + Coder valueCoder) { + return new MapCoder<>(keyCoder, valueCoder); + } + + @JsonCreator + public static MapCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + return of((Coder) components.get(0), (Coder) components.get(1)); + } + + /** + * Returns the key and value for an arbitrary element of this map, + * if it is non-empty, otherwise returns {@code null}. + */ + public static List getInstanceComponents( + Map exampleValue) { + for (Map.Entry entry : exampleValue.entrySet()) { + return Arrays.asList(entry.getKey(), entry.getValue()); + } + return null; + } + + public Coder getKeyCoder() { return keyCoder; } + public Coder getValueCoder() { return valueCoder; } + + ///////////////////////////////////////////////////////////////////////////// + + Coder keyCoder; + Coder valueCoder; + + MapCoder(Coder keyCoder, Coder valueCoder) { + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + } + + @Override + public void encode( + Map map, + OutputStream outStream, + Context context) + throws IOException, CoderException { + DataOutputStream dataOutStream = new DataOutputStream(outStream); + dataOutStream.writeInt(map.size()); + for (Entry entry : map.entrySet()) { + keyCoder.encode(entry.getKey(), outStream, context.nested()); + valueCoder.encode(entry.getValue(), outStream, context.nested()); + } + dataOutStream.flush(); + } + + @Override + public Map decode(InputStream inStream, Context context) + throws IOException, CoderException { + DataInputStream dataInStream = new DataInputStream(inStream); + int size = dataInStream.readInt(); + Map retval = new HashMap<>(); + for (int i = 0; i < size; ++i) { + K key = keyCoder.decode(inStream, context.nested()); + V value = valueCoder.decode(inStream, context.nested()); + retval.put(key, value); + } + return retval; + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(keyCoder, valueCoder); + } + + /** + * Not all maps have a deterministic encoding. + * + *

For example, HashMap comparison does not depend on element order, so + * two HashMap instances may be equal but produce different encodings. + */ + @Override + public boolean isDeterministic() { + return false; + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + Map map, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(4L); + for (Entry entry : map.entrySet()) { + keyCoder.registerByteSizeObserver( + entry.getKey(), observer, context.nested()); + valueCoder.registerByteSizeObserver( + entry.getValue(), observer, context.nested()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoderBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoderBase.java new file mode 100644 index 000000000000..e896e0d36dc1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoderBase.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * A abstract base class for MapCoder. Works around a Jackson2 bug tickled when building + * {@link MapCoder} directly (as of this writing, Jackson2 walks off the end of + * an array when it tries to deserialize a class with multiple generic type + * parameters). This should be removed in favor of a better workaround. + * @param the type of values being transcoded + */ +public abstract class MapCoderBase extends StandardCoder { + @JsonCreator + public static MapCoderBase of( + // N.B. typeId is a required parameter here, since a field named "@type" + // is presented to the deserializer as an input. + // + // If this method did not consume the field, Jackson2 would observe an + // unconsumed field and a returned value of a derived type. So Jackson2 + // would attempt to update the returned value with the unconsumed field + // data, The standard JsonDeserializer does not implement a mechanism for + // updating constructed values, so it would throw an exception, causing + // deserialization to fail. + @JsonProperty(value = "@type", required = false) String typeId, + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + return MapCoder.of(components); + } + + protected MapCoderBase() {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SerializableCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SerializableCoder.java new file mode 100644 index 000000000000..c078e6629a2b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SerializableCoder.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CloudObject; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.io.Serializable; + +/** + * An encoder of {@link java.io.Serializable} objects. + * + * To use, specify the coder type on a PCollection. + *

+ * {@code
+ *   PCollection records =
+ *       foo.apply(...).setCoder(SerializableCoder.of(MyRecord.class));
+ * }
+ * 
+ * + *

SerializableCoder does not guarantee a deterministic encoding, as Java + * Serialization may produce different binary encodings for two equivalent + * objects. + * + * @param the type of elements handled by this coder + */ +public class SerializableCoder + extends AtomicCoder { + /** + * Returns a {@code SerializableCoder} instance for the provided element type. + * @param the element type + */ + public static SerializableCoder of(Class type) { + return new SerializableCoder<>(type); + } + + @JsonCreator + public static SerializableCoder of(@JsonProperty("type") String classType) + throws ClassNotFoundException { + Class clazz = Class.forName(classType); + if (!Serializable.class.isAssignableFrom(clazz)) { + throw new ClassNotFoundException( + "Class " + classType + " does not implement Serializable"); + } + return of((Class) clazz); + } + + private final Class type; + + protected SerializableCoder(Class type) { + this.type = type; + } + + public Class getRecordType() { + return type; + } + + @Override + public void encode(T value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null record"); + } + try (ObjectOutputStream oos = new ObjectOutputStream(outStream)) { + oos.writeObject(value); + } catch (IOException exn) { + throw new CoderException("unable to serialize record " + value, exn); + } + } + + @Override + public T decode(InputStream inStream, Context context) + throws IOException, CoderException { + try (ObjectInputStream ois = new ObjectInputStream(inStream)) { + return type.cast(ois.readObject()); + } catch (ClassNotFoundException e) { + throw new CoderException("unable to deserialize record", e); + } + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + result.put("type", type.getName()); + return result; + } + + @Override + public boolean isDeterministic() { + return false; + } + + @Override + public boolean equals(Object other) { + if (getClass() != other.getClass()) { + return false; + } + return type == ((SerializableCoder) other).type; + } + + // This coder inherits isRegisterByteSizeObserverCheap, + // getEncodedElementByteSize and registerByteSizeObserver + // from StandardCoder. Looks like we cannot do much better + // in this case. +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SetCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SetCoder.java new file mode 100644 index 000000000000..1a234c7b40ed --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SetCoder.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A SetCoder encodes Sets. + * + * @param the type of the elements of the set + */ +public class SetCoder extends StandardCoder> { + + /** + * Produces a SetCoder with the given elementCoder. + */ + public static SetCoder of(Coder elementCoder) { + return new SetCoder<>(elementCoder); + } + + @JsonCreator + public static SetCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + public Coder getElementCoder() { return elementCoder; } + + ///////////////////////////////////////////////////////////////////////////// + + Coder elementCoder; + + SetCoder(Coder elementCoder) { + this.elementCoder = elementCoder; + } + + @Override + public void encode( + Set set, + OutputStream outStream, + Context context) + throws IOException, CoderException { + DataOutputStream dataOutStream = new DataOutputStream(outStream); + dataOutStream.writeInt(set.size()); + for (T element : set) { + elementCoder.encode(element, outStream, context.nested()); + } + dataOutStream.flush(); + } + + @Override + public Set decode(InputStream inStream, Context context) + throws IOException, CoderException { + DataInputStream dataInStream = new DataInputStream(inStream); + int size = dataInStream.readInt(); + Set retval = new HashSet(); + for (int i = 0; i < size; ++i) { + T element = elementCoder.decode(inStream, context.nested()); + retval.add(element); + } + return retval; + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(elementCoder); + } + + /** + * Not all sets have a deterministic encoding. + * + *

For example, HashSet comparison does not depend on element order, so + * two HashSet instances may be equal but produce different encodings. + */ + @Override + public boolean isDeterministic() { + return false; + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + Set set, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(4L); + for (T element : set) { + elementCoder.registerByteSizeObserver(element, observer, context.nested()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StandardCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StandardCoder.java new file mode 100644 index 000000000000..7a35fdcafbf2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StandardCoder.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addList; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A StandardCoder is one that defines equality, hashing, and printing + * via the class name and recursively using {@link #getComponents}. + * + * @param the type of the values being transcoded + */ +public abstract class StandardCoder implements Coder { + + protected StandardCoder() {} + + /** + * Returns the list of {@code Coder}s that are components of this + * {@code Coder}. Returns an empty list if this is an {@link AtomicCoder} (or + * other {@code Coder} with no components). + */ + public List> getComponents() { + List> coderArguments = getCoderArguments(); + if (coderArguments == null) { + return Collections.emptyList(); + } else { + return coderArguments; + } + } + + @Override + public boolean equals(Object o) { + if (this.getClass() != o.getClass()) { + return false; + } + StandardCoder that = (StandardCoder) o; + return this.getComponents().equals(that.getComponents()); + } + + @Override + public int hashCode() { + return getClass().hashCode() * 31 + getComponents().hashCode(); + } + + @Override + public String toString() { + String s = getClass().getName(); + s = s.substring(s.lastIndexOf('.') + 1); + List> componentCoders = getComponents(); + if (!componentCoders.isEmpty()) { + s += "("; + boolean first = true; + for (Coder componentCoder : componentCoders) { + if (first) { + first = false; + } else { + s += ", "; + } + s += componentCoder.toString(); + } + s += ")"; + } + return s; + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = CloudObject.forClass(getClass()); + + List> components = getComponents(); + if (!components.isEmpty()) { + List cloudComponents = new ArrayList<>(components.size()); + for (Coder coder : components) { + cloudComponents.add(coder.asCloudObject()); + } + addList(result, PropertyNames.COMPONENT_ENCODINGS, cloudComponents); + } + + return result; + } + + /** + * StandardCoder requires elements to be fully encoded and copied + * into a byte stream to determine the byte size of the element, + * which is considered expensive. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(T value, Context context) { + return false; + } + + /** + * Returns the size in bytes of the encoded value using this + * coder. Derived classes override this method if byte size can be + * computed with less computation or copying. + */ + protected long getEncodedElementByteSize(T value, Context context) + throws Exception { + try { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + encode(value, os, context); + return os.size(); + } catch (Exception exn) { + throw new IllegalArgumentException( + "Unable to encode element " + value + " with coder " + this, exn); + } + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. Calls + * getEncodedElementByteSize() and notifies ElementByteSizeObserver. + */ + @Override + public void registerByteSizeObserver( + T value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(getEncodedElementByteSize(value, context)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringUtf8Coder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringUtf8Coder.java new file mode 100644 index 000000000000..17995c31b65b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringUtf8Coder.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.io.ByteStreams; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; +import java.nio.charset.Charset; + +/** + * A StringUtf8Coder encodes Java Strings in UTF-8 encoding. + * If in a nested context, prefixes the string with a VarInt length field. + */ +public class StringUtf8Coder extends AtomicCoder { + @JsonCreator + public static StringUtf8Coder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final StringUtf8Coder INSTANCE = new StringUtf8Coder(); + + private static class Singletons { + private static final Charset UTF8 = Charset.forName("UTF-8"); + } + + // Writes a string with VarInt size prefix, supporting large strings. + private static void writeString(String value, DataOutputStream dos) + throws IOException { + byte[] bytes = value.getBytes(Singletons.UTF8); + VarInt.encode(bytes.length, dos); + dos.write(bytes); + } + + // Reads a string with VarInt size prefix, supporting large strings. + private static String readString(DataInputStream dis) throws IOException { + int len = VarInt.decodeInt(dis); + if (len < 0) { + throw new CoderException("Invalid encoded string length: " + len); + } + byte[] bytes = new byte[len]; + dis.readFully(bytes); + return new String(bytes, Singletons.UTF8); + } + + private StringUtf8Coder() {} + + @Override + public void encode(String value, OutputStream outStream, Context context) + throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null String"); + } + if (context.isWholeStream) { + outStream.write(value.getBytes(Singletons.UTF8)); + } else { + writeString(value, new DataOutputStream(outStream)); + } + } + + @Override + public String decode(InputStream inStream, Context context) + throws IOException { + if (context.isWholeStream) { + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + ByteStreams.copy(inStream, outStream); + // ByteArrayOutputStream.toString provides no Charset overloads. + return outStream.toString("UTF-8"); + } else { + try { + return readString(new DataInputStream(inStream)); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + protected long getEncodedElementByteSize(String value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null String"); + } + if (context.isWholeStream) { + return value.getBytes(Singletons.UTF8).length; + } else { + DataOutputStream stream = new DataOutputStream(new ByteArrayOutputStream()); + writeString(value, stream); + return stream.size(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoder.java new file mode 100644 index 000000000000..e49dfbb9c01c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoder.java @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.bigquery.model.TableRow; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A TableRowJsonCoder encodes BigQuery TableRow objects. + */ +public class TableRowJsonCoder extends AtomicCoder { + + @JsonCreator + public static TableRowJsonCoder of() { + return INSTANCE; + } + + @Override + public void encode(TableRow value, OutputStream outStream, Context context) + throws IOException { + String strValue = MAPPER.writeValueAsString(value); + StringUtf8Coder.of().encode(strValue, outStream, context); + } + + @Override + public TableRow decode(InputStream inStream, Context context) + throws IOException { + String strValue = StringUtf8Coder.of().decode(inStream, context); + return MAPPER.readValue(strValue, TableRow.class); + } + + @Override + protected long getEncodedElementByteSize(TableRow value, Context context) + throws Exception { + String strValue = MAPPER.writeValueAsString(value); + return StringUtf8Coder.of().getEncodedElementByteSize(strValue, context); + } + + ///////////////////////////////////////////////////////////////////////////// + + // FAIL_ON_EMPTY_BEANS is disabled in order to handle null values in + // TableRow. + private static final ObjectMapper MAPPER = + new ObjectMapper().disable(SerializationFeature.FAIL_ON_EMPTY_BEANS); + + private static final TableRowJsonCoder INSTANCE = new TableRowJsonCoder(); + + private TableRowJsonCoder() { + } + + /** + * TableCell can hold arbitrary Object instances, which makes the encoding + * non-deterministic. + */ + @Override + public boolean isDeterministic() { + return false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoder.java new file mode 100644 index 000000000000..93d080b7f01c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoder.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A TextualIntegerCoder encodes Integers as text. + */ +public class TextualIntegerCoder extends AtomicCoder { + @JsonCreator + public static TextualIntegerCoder of() { + return new TextualIntegerCoder(); + } + + + ///////////////////////////////////////////////////////////////////////////// + + private TextualIntegerCoder() {} + + @Override + public void encode(Integer value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + String textualValue = value.toString(); + StringUtf8Coder.of().encode(textualValue, outStream, context); + } + + @Override + public Integer decode(InputStream inStream, Context context) + throws IOException, CoderException { + String textualValue = StringUtf8Coder.of().decode(inStream, context); + try { + return Integer.valueOf(textualValue); + } catch (NumberFormatException exn) { + throw new CoderException("error when decoding a textual integer", exn); + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + protected long getEncodedElementByteSize(Integer value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + String textualValue = value.toString(); + return StringUtf8Coder.of().getEncodedElementByteSize(textualValue, context); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/URICoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/URICoder.java new file mode 100644 index 000000000000..ed5ae45c53e7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/URICoder.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.URISyntaxException; + +/** + * A {@code URICoder} encodes/decodes {@link URI}s by conversion to/from {@link String}, delegating + * encoding/decoding of the string to {@link StringUtf8Coder}. + */ +public class URICoder extends AtomicCoder { + + @JsonCreator + public static URICoder of() { + return INSTANCE; + } + + private static final URICoder INSTANCE = new URICoder(); + private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of(); + + private URICoder() {} + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public void encode(URI value, OutputStream outStream, Context context) + throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null URI"); + } + STRING_CODER.encode(value.toString(), outStream, context); + } + + @Override + public URI decode(InputStream inStream, Context context) + throws IOException { + try { + return new URI(STRING_CODER.decode(inStream, context)); + } catch (URISyntaxException exn) { + throw new CoderException(exn); + } + } + + @Override + public boolean isDeterministic() { + return STRING_CODER.isDeterministic(); + } + + @Override + protected long getEncodedElementByteSize(URI value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null URI"); + } + return STRING_CODER.getEncodedElementByteSize(value.toString(), context); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarIntCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarIntCoder.java new file mode 100644 index 000000000000..eff03fb73732 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarIntCoder.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A VarIntCoder encodes Integers using between 1 and 5 bytes. Negative + * numbers always take 5 bytes, so BigEndianIntegerCoder may be preferable for + * ints that are known to often be large or negative. + */ +public class VarIntCoder extends AtomicCoder { + @JsonCreator + public static VarIntCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final VarIntCoder INSTANCE = + new VarIntCoder(); + + private VarIntCoder() {} + + @Override + public void encode(Integer value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + VarInt.encode(value.intValue(), outStream); + } + + @Override + public Integer decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return VarInt.decodeInt(inStream); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Integer value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Integer value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + return VarInt.getLength(value.longValue()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarLongCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarLongCoder.java new file mode 100644 index 000000000000..74f9b6092288 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarLongCoder.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A VarLongCoder encodes longs using between 1 and 10 bytes. Negative + * numbers always take 10 bytes, so BigEndianLongCoder may be preferable for + * longs that are known to often be large or negative. + */ +public class VarLongCoder extends AtomicCoder { + @JsonCreator + public static VarLongCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final VarLongCoder INSTANCE = + new VarLongCoder(); + + private VarLongCoder() {} + + @Override + public void encode(Long value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + VarInt.encode(value.longValue(), outStream); + } + + @Override + public Long decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return VarInt.decodeLong(inStream); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + @Override + public boolean isDeterministic() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Long value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Long value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + return VarInt.getLength(value.longValue()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VoidCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VoidCoder.java new file mode 100644 index 000000000000..fc9a1e0958b2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VoidCoder.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A VoidCoder encodes Voids. Uses zero bytes per Void. + */ +public class VoidCoder extends AtomicCoder { + @JsonCreator + public static VoidCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final VoidCoder INSTANCE = new VoidCoder(); + + private VoidCoder() {} + + @Override + public void encode(Void value, OutputStream outStream, Context context) { + // Nothing to write! + } + + @Override + public Void decode(InputStream inStream, Context context) { + // Nothing to read! + return null; + } + + @Override + public boolean isDeterministic() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Void value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Void value, Context context) + throws Exception { + return 0; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/package-info.java new file mode 100644 index 000000000000..ea305e776bc9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/package-info.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.coders.Coder}s + * to specify how data is encoded to and decoded from byte strings. + * + *

During execution of a Pipeline, elements in a + * {@link com.google.cloud.dataflow.sdk.values.PCollection} + * may need to be encoded into byte strings. + * This happens both at the beginning and end of a pipeline when data is read from and written to + * persistent storage and also during execution of a pipeline when elements are communicated between + * machines. + * + *

Exactly when PCollection elements are encoded during execution depends on which + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} is being used and how that runner + * chooses to execute the pipeline. As such, Dataflow requires that all PCollections have an + * appropriate Coder in case it becomes necessary. In many cases, the Coder can be inferred from + * the available Java type + * information and the Pipeline's {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry}. It + * can be specified per PCollection via + * {@link com.google.cloud.dataflow.sdk.values.PCollection#setCoder(Coder)} or per type using the + * {@link com.google.cloud.dataflow.sdk.coders.DefaultCoder} annotation. + * + *

This package provides a number of coders for common types like {@code Integer}, + * {@code String}, and {@code List}, as well as coders like + * {@link com.google.cloud.dataflow.sdk.coders.AvroCoder} that can be used to encode many custom + * types. + * + */ +package com.google.cloud.dataflow.sdk.coders; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroIO.java new file mode 100644 index 000000000000..7a9e6ea3d394 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroIO.java @@ -0,0 +1,678 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.ValueWithMetadata; +import static com.google.cloud.dataflow.sdk.util.CloudSourceUtils.readElemsFromSource; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.worker.AvroSink; +import com.google.cloud.dataflow.sdk.runners.worker.AvroSource; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.reflect.ReflectData; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * Transforms for reading and writing Avro files. + * + *

To read a {@link PCollection} from one or more Avro files, use + * {@link AvroIO.Read}, specifying {@link AvroIO.Read#from} to specify + * the path of the file(s) to read from (e.g., a local filename or + * filename pattern if running locally, or a Google Cloud Storage + * filename or filename pattern of the form + * {@code "gs:///"}), and optionally + * {@link AvroIO.Read#named} to specify the name of the pipeline step. + * + *

It is required to specify {@link AvroIO.Read#withSchema}. To + * read specific records, such as Avro-generated classes, provide an + * Avro-generated class type. To read GenericRecords, provide either + * an org.apache.avro.Schema or a schema in a JSON-encoded string form. + * An exception will be thrown if a record doesn't match the specified + * schema. + * + *

For example: + *

 {@code
+ * Pipeline p = ...;
+ *
+ * // A simple Read of a local file (only runs locally):
+ * PCollection records =
+ *     p.apply(AvroIO.Read.from("/path/to/file.avro")
+ *                        .withSchema(AvroAutoGenClass.class));
+ *
+ * // A Read from a GCS file (runs locally and via the Google Cloud
+ * // Dataflow service):
+ * Schema schema = new Schema.Parser().parse(new File(
+ *     "gs://my_bucket/path/to/schema.avsc"));
+ * PCollection records =
+ *     p.apply(AvroIO.Read.named("ReadFromAvro")
+ *                        .from("gs://my_bucket/path/to/records-*.avro")
+ *                        .withSchema(schema));
+ * } 
+ * + *

To write a {@link PCollection} to one or more Avro files, use + * {@link AvroIO.Write}, specifying {@link AvroIO.Write#to} to specify + * the path of the file to write to (e.g., a local filename or sharded + * filename pattern if running locally, or a Google Cloud Storage + * filename or sharded filename pattern of the form + * {@code "gs:///"}), and optionally + * {@link AvroIO.Write#named} to specify the name of the pipeline step. + * + *

It is required to specify {@link AvroIO.Write#withSchema}. To + * write specific records, such as Avro-generated classes, provide an + * Avro-generated class type. To write GenericRecords, provide either + * an org.apache.avro.Schema or a schema in a JSON-encoded string form. + * An exception will be thrown if a record doesn't match the specified + * schema. + * + *

For example: + *

 {@code
+ * // A simple Write to a local file (only runs locally):
+ * PCollection records = ...;
+ * records.apply(AvroIO.Write.to("/path/to/file.avro")
+ *                           .withSchema(AvroAutoGenClass.class));
+ *
+ * // A Write to a sharded GCS file (runs locally and via the Google Cloud
+ * // Dataflow service):
+ * Schema schema = new Schema.Parser().parse(new File(
+ *     "gs://my_bucket/path/to/schema.avsc"));
+ * PCollection records = ...;
+ * records.apply(AvroIO.Write.named("WriteToAvro")
+ *                           .to("gs://my_bucket/path/to/numbers")
+ *                           .withSchema(schema)
+ *                           .withSuffix(".avro"));
+ * } 
+ */ +public class AvroIO { + + /** + * A root PTransform that reads from an Avro file (or multiple Avro + * files matching a pattern) and returns a PCollection containing + * the decoding of each record. + */ + public static class Read { + + /** + * Returns an AvroIO.Read PTransform with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(GenericRecord.class).named(name); + } + + /** + * Returns an AvroIO.Read PTransform that reads from the file(s) + * with the given name or pattern. This can be a local filename + * or filename pattern (if running locally), or a Google Cloud + * Storage filename or filename pattern of the form + * {@code "gs:///"}) (if running locally or via + * the Google Cloud Dataflow service). Standard + * Java Filesystem glob patterns ("*", "?", "[..]") are supported. + */ + public static Bound from(String filepattern) { + return new Bound<>(GenericRecord.class).from(filepattern); + } + + /** + * Returns an AvroIO.Read PTransform that reads Avro file(s) + * containing records whose type is the specified Avro-generated class. + * + * @param the type of the decoded elements, and the elements + * of the resulting PCollection + */ + public static Bound withSchema(Class type) { + return new Bound<>(type).withSchema(type); + } + + /** + * Returns an AvroIO.Read PTransform that reads Avro file(s) + * containing records of the specified schema. + */ + public static Bound withSchema(Schema schema) { + return new Bound<>(GenericRecord.class).withSchema(schema); + } + + /** + * Returns an AvroIO.Read PTransform that reads Avro file(s) + * containing records of the specified schema in a JSON-encoded + * string form. + */ + public static Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + /** + * A PTransform that reads from an Avro file (or multiple Avro + * files matching a pattern) and returns a bounded PCollection containing + * the decoding of each record. + * + * @param the type of each of the elements of the resulting + * PCollection + */ + public static class Bound + extends PTransform> { + private static final long serialVersionUID = 0; + + /** The filepattern to read from. */ + @Nullable final String filepattern; + /** The class type of the records. */ + final Class type; + /** The schema of the input file. */ + @Nullable final Schema schema; + + Bound(Class type) { + this(null, null, type, null); + } + + Bound(String name, String filepattern, Class type, Schema schema) { + super(name); + this.filepattern = filepattern; + this.type = type; + this.schema = schema; + } + + /** + * Returns a new AvroIO.Read PTransform that's like this one but + * with the given step name. Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filepattern, type, schema); + } + + /** + * Returns a new AvroIO.Read PTransform that's like this one but + * that reads from the file(s) with the given name or pattern. + * (See {@link AvroIO.Read#from} for a description of + * filepatterns.) Does not modify this object. + */ + public Bound from(String filepattern) { + return new Bound<>(name, filepattern, type, schema); + } + + /** + * Returns a new AvroIO.Read PTransform that's like this one but + * that reads Avro file(s) containing records whose type is the + * specified Avro-generated class. Does not modify this object. + * + * @param the type of the decoded elements, and the elements of + * the resulting PCollection + */ + public Bound withSchema(Class type) { + return new Bound<>(name, filepattern, type, ReflectData.get().getSchema(type)); + } + + /** + * Returns a new AvroIO.Read PTransform that's like this one but + * that reads Avro file(s) containing records of the specified schema. + * Does not modify this object. + */ + public Bound withSchema(Schema schema) { + return new Bound<>(name, filepattern, GenericRecord.class, schema); + } + + /** + * Returns a new AvroIO.Read PTransform that's like this one but + * that reads Avro file(s) containing records of the specified schema + * in a JSON-encoded string form. Does not modify this object. + */ + public Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + @Override + public PCollection apply(PInput input) { + if (filepattern == null) { + throw new IllegalStateException( + "need to set the filepattern of an AvroIO.Read transform"); + } + if (schema == null) { + throw new IllegalStateException( + "need to set the schema of an AvroIO.Read transform"); + } + + // Force the output's Coder to be what the read is using, and + // unchangeable later, to ensure that we read the input in the + // format specified by the Read transform. + return PCollection.createPrimitiveOutputInternal(new GlobalWindow()) + .setCoder(getDefaultOutputCoder()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return AvroCoder.of(type, schema); + } + + @Override + protected String getKindString() { return "AvroIO.Read"; } + + public String getFilepattern() { + return filepattern; + } + + public Schema getSchema() { + return schema; + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateReadHelper(transform, context); + } + }); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A root PTransform that writes a PCollection to an Avro file (or + * multiple Avro files matching a sharding pattern). + */ + public static class Write { + + /** + * Returns an AvroIO.Write PTransform with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(GenericRecord.class).named(name); + } + + /** + * Returns an AvroIO.Write PTransform that writes to the file(s) + * with the given prefix. This can be a local filename + * (if running locally), or a Google Cloud Storage filename of + * the form {@code "gs:///"}) + * (if running locally or via the Google Cloud Dataflow service). + * + *

The files written will begin with this prefix, followed by + * a shard identifier (see {@link Bound#withNumShards}, and end + * in a common extension, if given by {@link Bound#withSuffix}. + */ + public static Bound to(String prefix) { + return new Bound<>(GenericRecord.class).to(prefix); + } + + /** + * Returns an AvroIO.Write PTransform that writes to the file(s) with the + * given filename suffix. + */ + public static Bound withSuffix(String filenameSuffix) { + return new Bound<>(GenericRecord.class).withSuffix(filenameSuffix); + } + + /** + * Returns an AvroIO.Write PTransform that uses the provided shard count. + * + *

Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + */ + public static Bound withNumShards(int numShards) { + return new Bound<>(GenericRecord.class).withNumShards(numShards); + } + + /** + * Returns an AvroIO.Write PTransform that uses the given shard name + * template. + * + * See {@link ShardNameTemplate} for a description of shard templates. + */ + public static Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(GenericRecord.class).withShardNameTemplate(shardTemplate); + } + + /** + * Returns an AvroIO.Write PTransform that forces a single file as + * output. + * + *

Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + */ + public static Bound withoutSharding() { + return new Bound<>(GenericRecord.class).withoutSharding(); + } + + /** + * Returns an AvroIO.Write PTransform that writes Avro file(s) + * containing records whose type is the specified Avro-generated class. + * + * @param the type of the elements of the input PCollection + */ + public static Bound withSchema(Class type) { + return new Bound<>(type).withSchema(type); + } + + /** + * Returns an AvroIO.Write PTransform that writes Avro file(s) + * containing records of the specified schema. + */ + public static Bound withSchema(Schema schema) { + return new Bound<>(GenericRecord.class).withSchema(schema); + } + + /** + * Returns an AvroIO.Write PTransform that writes Avro file(s) + * containing records of the specified schema in a JSON-encoded + * string form. + */ + public static Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + /** + * A PTransform that writes a bounded PCollection to an Avro file (or + * multiple Avro files matching a sharding pattern). + * + * @param the type of each of the elements of the input PCollection + */ + public static class Bound + extends PTransform, PDone> { + private static final long serialVersionUID = 0; + + /** The filename to write to. */ + @Nullable final String filenamePrefix; + /** Suffix to use for each filename. */ + final String filenameSuffix; + /** Requested number of shards. 0 for automatic. */ + final int numShards; + /** Shard template string. */ + final String shardTemplate; + /** The class type of the records. */ + final Class type; + /** The schema of the output file. */ + @Nullable final Schema schema; + + Bound(Class type) { + this(null, null, "", 0, ShardNameTemplate.INDEX_OF_MAX, type, null); + } + + Bound(String name, String filenamePrefix, String filenameSuffix, + int numShards, String shardTemplate, + Class type, Schema schema) { + super(name); + this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + this.numShards = numShards; + this.shardTemplate = shardTemplate; + this.type = type; + this.schema = schema; + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * with the given step name. Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, + type, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that writes to the file(s) with the given filename prefix. + * + *

See {@link Write#to(String) Write.to(String)} for more information. + * + *

Does not modify this object. + */ + public Bound to(String filenamePrefix) { + validateOutputComponent(filenamePrefix); + return new Bound<>(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, + type, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that writes to the file(s) with the given filename suffix. + * + *

Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withSuffix(String filenameSuffix) { + validateOutputComponent(filenameSuffix); + return new Bound<>(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, + type, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that uses the provided shard count. + * + *

Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + *

Does not modify this object. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + * @see ShardNameTemplate + */ + public Bound withNumShards(int numShards) { + Preconditions.checkArgument(numShards >= 0); + return new Bound<>(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, + type, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that uses the given shard name template. + * + *

Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, + type, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that forces a single file as output. + * + *

This is a shortcut for + * {@code .withNumShards(1).withShardNameTemplate("")} + * + *

Does not modify this object. + */ + public Bound withoutSharding() { + return new Bound<>(name, filenamePrefix, filenameSuffix, 1, "", type, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that writes to Avro file(s) containing records whose type is the + * specified Avro-generated class. Does not modify this object. + * + * @param the type of the elements of the input PCollection + */ + public Bound withSchema(Class type) { + return new Bound<>(name, filenamePrefix, filenameSuffix, + numShards, shardTemplate, + type, ReflectData.get().getSchema(type)); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that writes to Avro file(s) containing records of the specified + * schema. Does not modify this object. + */ + public Bound withSchema(Schema schema) { + return new Bound<>(name, filenamePrefix, filenameSuffix, + numShards, shardTemplate, + GenericRecord.class, schema); + } + + /** + * Returns a new AvroIO.Write PTransform that's like this one but + * that writes to Avro file(s) containing records of the specified + * schema in a JSON-encoded string form. Does not modify this object. + */ + public Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + @Override + public PDone apply(PCollection input) { + if (filenamePrefix == null) { + throw new IllegalStateException( + "need to set the filename prefix of an AvroIO.Write transform"); + } + if (schema == null) { + throw new IllegalStateException( + "need to set the schema of an AvroIO.Write transform"); + } + + return new PDone(); + } + + /** + * Returns the current shard name template string. + */ + public String getShardNameTemplate() { + return shardTemplate; + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + @Override + protected String getKindString() { return "AvroIO.Write"; } + + public String getFilenamePrefix() { + return filenamePrefix; + } + + public String getShardTemplate() { + return shardTemplate; + } + + public int getNumShards() { + return numShards; + } + + public String getFilenameSuffix() { + return filenameSuffix; + } + + public Class getType() { + return type; + } + + public Schema getSchema() { + return schema; + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateWriteHelper(transform, context); + } + }); + } + } + } + + // Pattern which matches old-style shard output patterns, which are now + // disallowed. + private static final Pattern SHARD_OUTPUT_PATTERN = + Pattern.compile("@([0-9]+|\\*)"); + + private static void validateOutputComponent(String partialFilePattern) { + Preconditions.checkArgument( + !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(), + "Output name components are not allowed to contain @* or @N patterns: " + + partialFilePattern); + } + + ///////////////////////////////////////////////////////////////////////////// + + private static void evaluateReadHelper( + Read.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + AvroSource source = new AvroSource<>( + transform.filepattern, null, null, WindowedValue.getValueOnlyCoder( + transform.getDefaultOutputCoder())); + List> elems = readElemsFromSource(source); + List> output = new ArrayList<>(); + for (WindowedValue elem : elems) { + output.add(ValueWithMetadata.of(elem)); + } + context.setPCollectionValuesWithMetadata(transform.getOutput(), output); + } + + private static void evaluateWriteHelper( + Write.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + List> elems = context.getPCollectionWindowedValues(transform.getInput()); + int numShards = transform.numShards; + if (numShards < 1) { + // System gets to choose. For direct mode, choose 1. + numShards = 1; + } + AvroSink writer = new AvroSink<>(transform.filenamePrefix, transform.shardTemplate, + transform.filenameSuffix, numShards, + WindowedValue.getValueOnlyCoder( + AvroCoder.of(transform.type, transform.schema))); + try (Sink.SinkWriter> sink = writer.writer()) { + for (WindowedValue elem : elems) { + sink.add(elem); + } + } catch (IOException exn) { + throw new RuntimeException( + "unable to write to output file \"" + transform.filenamePrefix + "\"", + exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BigQueryIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BigQueryIO.java new file mode 100644 index 000000000000..2fffe4de2c45 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BigQueryIO.java @@ -0,0 +1,937 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.api.client.json.JsonFactory; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.worker.BigQuerySource; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.BigQueryTableInserter; +import com.google.cloud.dataflow.sdk.util.CloudSourceUtils; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicLong; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Transformations for reading and writing + * BigQuery tables. + *

Table References

+ * A fully-qualified BigQuery table name consists of three components: + *
    + *
  • {@code projectId}: the Cloud project id (defaults to + * {@link GcpOptions#getProject()}). + *
  • {@code datasetId}: the BigQuery dataset id, unique within a project. + *
  • {@code tableId}: a table id, unique within a dataset. + *
+ *

+ * BigQuery table references are stored as a {@link TableReference}, which comes + * from the BigQuery Java Client API. + * Tables can be referred to as Strings, with or without the {@code projectId}. + * A helper function is provided ({@link BigQueryIO#parseTableSpec(String)}), + * which parses the following string forms into a {@link TableReference}: + *

    + *
  • [{@code project_id}]:[{@code dataset_id}].[{@code table_id}] + *
  • [{@code dataset_id}].[{@code table_id}] + *
+ *

Reading

+ * To read from a BigQuery table, apply a {@link BigQueryIO.Read} transformation. + * This produces a {@code PCollection} as output: + *
{@code
+ * PCollection shakespeare = pipeline.apply(
+ *     BigQueryIO.Read
+ *         .named("Read")
+ *         .from("clouddataflow-readonly:samples.weather_stations");
+ * }
+ *

Writing

+ * To write to a BigQuery table, apply a {@link BigQueryIO.Write} transformation. + * This consumes a {@code PCollection} as input. + *

+ *

{@code
+ * PCollection quotes = ...
+ *
+ * List fields = new ArrayList<>();
+ * fields.add(new TableFieldSchema().setName("source").setType("STRING"));
+ * fields.add(new TableFieldSchema().setName("quote").setType("STRING"));
+ * TableSchema schema = new TableSchema().setFields(fields);
+ *
+ * quotes.apply(BigQueryIO.Write
+ *     .named("Write")
+ *     .to("my-project:output.output_table")
+ *     .withSchema(schema)
+ *     .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
+ * }
+ *

+ * See {@link BigQueryIO.Write} for details on how to specify if a write should + * append to an existing table, replace the table, or verify that the table is + * empty. + * + * @see TableRow + */ +public class BigQueryIO { + private static final Logger LOG = LoggerFactory.getLogger(BigQueryIO.class); + + /** + * Singleton instance of the JSON factory used to read and write JSON + * formatted rows. + */ + private static final JsonFactory JSON_FACTORY = Transport.getJsonFactory(); + + /** + * Project IDs must contain 6-63 lowercase letters, digits, or dashes. + * IDs must start with a letter and may not end with a dash. + * This regex isn't exact - this allows for patterns that would be rejected by + * the service, but this is sufficient for basic parsing of table references. + */ + private static final String PROJECT_ID_REGEXP = + "[a-z][-a-z0-9:.]{4,61}[a-z0-9]"; + + /** + * Regular expression which matches Dataset IDs. + */ + private static final String DATASET_REGEXP = "[-\\w.]{1,1024}"; + + /** + * Regular expression which matches Table IDs. + */ + private static final String TABLE_REGEXP = "[-\\w$@]{1,1024}"; + + /** + * Matches table specifications in the form + * "[project_id]:[dataset_id].[table_id]" or "[dataset_id].[table_id]". + */ + private static final String DATASET_TABLE_REGEXP = String.format( + "((?%s):)?(?%s)\\.(?%s)", + PROJECT_ID_REGEXP, DATASET_REGEXP, TABLE_REGEXP); + + private static final Pattern TABLE_SPEC = + Pattern.compile(DATASET_TABLE_REGEXP); + + /** + * Parse a table specification in the form + * "[project_id]:[dataset_id].[table_id]" or "[dataset_id].[table_id]". + *

+ * If the project id is omitted, the default project id is used. + */ + public static TableReference parseTableSpec(String tableSpec) { + Matcher match = TABLE_SPEC.matcher(tableSpec); + if (!match.matches()) { + throw new IllegalArgumentException( + "Table reference is not in [project_id]:[dataset_id].[table_id] " + + "format: " + tableSpec); + } + + TableReference ref = new TableReference(); + ref.setProjectId(match.group("PROJECT")); + + return ref + .setDatasetId(match.group("DATASET")) + .setTableId(match.group("TABLE")); + } + + /** + * Returns a canonical string representation of the TableReference. + */ + public static String toTableSpec(TableReference ref) { + StringBuilder sb = new StringBuilder(); + if (ref.getProjectId() != null) { + sb.append(ref.getProjectId()); + sb.append(":"); + } + + sb.append(ref.getDatasetId()) + .append('.') + .append(ref.getTableId()); + return sb.toString(); + } + + /** + * A PTransform that reads from a BigQuery table and returns a + * {@code PCollection} containing each of the rows of the table. + *

+ * Each TableRow record contains values indexed by column name. Here is a + * sample processing function which processes a "line" column from rows: + *


+   * static class ExtractWordsFn extends DoFn{@literal } {
+   *   {@literal @}Override
+   *   public void processElement(ProcessContext c) {
+   *     // Get the "line" field of the TableRow object, split it into words, and emit them.
+   *     TableRow row = c.element();
+   *     String[] words = row.get("line").toString().split("[^a-zA-Z']+");
+   *     for (String word : words) {
+   *       if (!word.isEmpty()) {
+   *         c.output(word);
+   *       }
+   *     }
+   *   }
+   * }
+   * 
+ */ + public static class Read { + public static Bound named(String name) { + return new Bound().named(name); + } + + /** + * Reads a BigQuery table specified as + * "[project_id]:[dataset_id].[table_id]" or "[dataset_id].[table_id]" for + * tables within the current project. + */ + public static Bound from(String tableSpec) { + return new Bound().from(tableSpec); + } + + /** + * Reads a BigQuery table specified as a TableReference object. + */ + public static Bound from(TableReference table) { + return new Bound().from(table); + } + + /** + * Disables BigQuery table validation which is enabled by default. + */ + public static Bound withoutValidation() { + return new Bound().withoutValidation(); + } + + /** + * A PTransform that reads from a BigQuery table and returns a bounded + * {@code PCollection}. + */ + public static class Bound + extends PTransform> { + TableReference table; + final boolean validate; + + Bound() { + this.validate = true; + } + + Bound(String name, TableReference reference, boolean validate) { + super(name); + this.table = reference; + this.validate = validate; + } + + /** + * Sets the name associated with this transformation. + */ + public Bound named(String name) { + return new Bound(name, table, validate); + } + + /** + * Sets the table specification. + *

+ * Refer to {@link #parseTableSpec(String)} for the specification format. + */ + public Bound from(String tableSpec) { + return from(parseTableSpec(tableSpec)); + } + + /** + * Sets the table specification. + */ + public Bound from(TableReference table) { + return new Bound(name, table, validate); + } + + /** + * Disable table validation. + */ + public Bound withoutValidation() { + return new Bound(name, table, false); + } + + @Override + public PCollection apply(PInput input) { + if (table == null) { + throw new IllegalStateException( + "must set the table reference of a BigQueryIO.Read transform"); + } + return PCollection.createPrimitiveOutputInternal( + new GlobalWindow()) + // Force the output's Coder to be what the read is using, and + // unchangeable later, to ensure that we read the input in the + // format specified by the Read transform. + .setCoder(TableRowJsonCoder.of()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return TableRowJsonCoder.of(); + } + + @Override + protected String getKindString() { return "BigQueryIO.Read"; } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateReadHelper(transform, context); + } + }); + } + + /** + * Returns the table to write. + */ + public TableReference getTable() { + return table; + } + + /** + * Returns true if table validation is enabled. + */ + public boolean getValidate() { + return validate; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A PTransform that writes a {@code PCollection} containing rows + * to a BigQuery table. + *

+ * By default, tables will be created if they do not exist, which + * corresponds to a {@code CreateDisposition.CREATE_IF_NEEDED} disposition + * which matches the default of BigQuery's Jobs API. A schema must be + * provided (via {@link Write#withSchema}), or else the transform may fail + * at runtime with an {@link java.lang.IllegalArgumentException}. + *

+ * By default, writes require an empty table, which corresponds to + * a {@code WriteDisposition.WRITE_EMPTY} disposition which matches the + * default of BigQuery's Jobs API. + *

+ * Here is a sample transform which produces TableRow values containing + * "word" and "count" columns: + *


+   * static class FormatCountsFn extends DoFnP{@literal , TableRow>} {
+   *   {@literal @}Override
+   *   public void processElement(ProcessContext c) {
+   *     TableRow row = new TableRow()
+   *         .set("word", c.element().getKey())
+   *         .set("count", c.element().getValue().intValue());
+   *     c.output(row);
+   *   }
+   * }
+   * 
+ */ + public static class Write { + + /** + * An enumeration type for the BigQuery create disposition strings publicly + * documented as {@code CREATE_NEVER}, and {@code CREATE_IF_NEEDED}. + */ + public enum CreateDisposition { + /** + * Specifics that tables should not be created. + *

+ * If the output table does not exist, the write fails. + */ + CREATE_NEVER, + + /** + * Specifies that tables should be created if needed. This is the default + * behavior. + *

+ * Requires that a table schema is provided via {@link Write#withSchema}. + * This precondition is checked before starting a job. The schema is + * not required to match an existing table's schema. + *

+ * When this transformation is executed, if the output table does not + * exist, the table is created from the provided schema. Note that even if + * the table exists, it may be recreated if necessary when paired with a + * {@link WriteDisposition#WRITE_TRUNCATE}. + */ + CREATE_IF_NEEDED + } + + /** + * An enumeration type for the BigQuery write disposition strings publicly + * documented as {@code WRITE_TRUNCATE}, {@code WRITE_APPEND}, and + * {@code WRITE_EMPTY}. + */ + public enum WriteDisposition { + /** + * Specifies that write should replace a table. + *

+ * The replacement may occur in multiple steps - for instance by first + * removing the existing table, then creating a replacement, then filling + * it in. This is not an atomic operation, and external programs may + * see the table in any of these intermediate steps. + */ + WRITE_TRUNCATE, + + /** + * Specifies that rows may be appended to an existing table. + */ + WRITE_APPEND, + + /** + * Specifies that the output table must be empty. This is the default + * behavior. + *

+ * If the output table is not empty, the write fails at runtime. + *

+ * This check may occur long before data is written, and does not + * guarantee exclusive access to the table. If two programs are run + * concurrently, each specifying the same output table and + * a {@link WriteDisposition} of {@code WRITE_EMPTY}, it is possible + * for both to succeed. + */ + WRITE_EMPTY + } + + /** + * Sets the name associated with this transformation. + */ + public static Bound named(String name) { + return new Bound().named(name); + } + + /** + * Creates a write transformation for the given table specification. + *

+ * Refer to {@link #parseTableSpec(String)} for the specification format. + */ + public static Bound to(String tableSpec) { + return new Bound().to(tableSpec); + } + + /** Creates a write transformation for the given table. */ + public static Bound to(TableReference table) { + return new Bound().to(table); + } + + /** + * Specifies a table schema to use in table creation. + *

+ * The schema is required only if writing to a table which does not already + * exist, and {@link BigQueryIO.Write.CreateDisposition} is set to + * {@code CREATE_IF_NEEDED}. + */ + public static Bound withSchema(TableSchema schema) { + return new Bound().withSchema(schema); + } + + /** Specifies options for creating the table. */ + public static Bound withCreateDisposition(CreateDisposition disposition) { + return new Bound().withCreateDisposition(disposition); + } + + /** Specifies options for writing to the table. */ + public static Bound withWriteDisposition(WriteDisposition disposition) { + return new Bound().withWriteDisposition(disposition); + } + + /** + * Disables BigQuery table validation which is enabled by default. + */ + public static Bound withoutValidation() { + return new Bound().withoutValidation(); + } + + /** + * A PTransform that can write either a bounded or unbounded + * {@code PCollection}s to a BigQuery table. + */ + public static class Bound + extends PTransform, PDone> { + final TableReference table; + + // Table schema. The schema is required only if the table does not exist. + final TableSchema schema; + + // Options for creating the table. Valid values are CREATE_IF_NEEDED and + // CREATE_NEVER. + final CreateDisposition createDisposition; + + // Options for writing to the table. Valid values are WRITE_TRUNCATE, + // WRITE_APPEND and WRITE_EMPTY. + final WriteDisposition writeDisposition; + + // An option to indicate if table validation is desired. Default is true. + final boolean validate; + + public Bound() { + this.table = null; + this.schema = null; + this.createDisposition = CreateDisposition.CREATE_IF_NEEDED; + this.writeDisposition = WriteDisposition.WRITE_EMPTY; + this.validate = true; + } + + Bound(String name, TableReference ref, TableSchema schema, + CreateDisposition createDisposition, + WriteDisposition writeDisposition, + boolean validate) { + super(name); + this.table = ref; + this.schema = schema; + this.createDisposition = createDisposition; + this.writeDisposition = writeDisposition; + this.validate = validate; + } + + /** + * Sets the name associated with this transformation. + */ + public Bound named(String name) { + return new Bound(name, table, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Specifies the table specification. + *

+ * Refer to {@link #parseTableSpec(String)} for the specification format. + */ + public Bound to(String tableSpec) { + return to(parseTableSpec(tableSpec)); + } + + /** + * Specifies the table to be written to. + */ + public Bound to(TableReference table) { + return new Bound(name, table, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Specifies the table schema, used if the table is created. + */ + public Bound withSchema(TableSchema schema) { + return new Bound(name, table, schema, createDisposition, + writeDisposition, validate); + } + + /** Specifies options for creating the table. */ + public Bound withCreateDisposition(CreateDisposition createDisposition) { + return new Bound(name, table, schema, createDisposition, + writeDisposition, validate); + } + + /** Specifies options for writing the table. */ + public Bound withWriteDisposition(WriteDisposition writeDisposition) { + return new Bound(name, table, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Disable table validation. + */ + public Bound withoutValidation() { + return new Bound(name, table, schema, createDisposition, writeDisposition, false); + } + + @Override + public PDone apply(PCollection input) { + if (table == null) { + throw new IllegalStateException( + "must set the table reference of a BigQueryIO.Write transform"); + } + + if (createDisposition == CreateDisposition.CREATE_IF_NEEDED && + schema == null) { + throw new IllegalArgumentException( + "CreateDisposition is CREATE_IF_NEEDED, " + + "however no schema was provided."); + } + + // In streaming, BigQuery write is taken care of by StreamWithDeDup transform. + BigQueryOptions options = getPipeline().getOptions().as(BigQueryOptions.class); + if (options.isStreaming()) { + return input.apply(new StreamWithDeDup(table, schema)); + } + + return new PDone(); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + @Override + protected String getKindString() { return "BigQueryIO.Write"; } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateWriteHelper(transform, context); + } + }); + } + + /** Returns the create disposition. */ + public CreateDisposition getCreateDisposition() { + return createDisposition; + } + + /** Returns the write disposition. */ + public WriteDisposition getWriteDisposition() { + return writeDisposition; + } + + /** Returns the table schema. */ + public TableSchema getSchema() { + return schema; + } + + /** Returns the table reference. */ + public TableReference getTable() { + return table; + } + + /** Returns true if table validation is enabled. */ + public boolean getValidate() { + return validate; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Implementation of DoFn to perform streaming BigQuery write. + */ + private static class StreamingWriteFn extends DoFn>, Void> + implements DoFn.RequiresKeyedState { + + /** + * Class to accumulate BigQuery row data as a list of String. + * DoFn implementation must be Serializable, but BigQuery classes, + * such as TableRow are not. Therefore, convert into JSON String + * for accumulation. + */ + private static class JsonTableRows implements Iterable, Serializable { + + /** The list where BigQuery row data is accumulated. */ + private final List jsonRows = new ArrayList<>(); + + /** Iterator of JsonTableRows converts the row in String to TableRow. */ + static class JsonTableRowIterator implements Iterator { + + private final Iterator iteratorInternal; + + /** Constructor. */ + JsonTableRowIterator(List jsonRowList) { + iteratorInternal = jsonRowList.iterator(); + } + + @Override + public boolean hasNext() { + return iteratorInternal.hasNext(); + } + + @Override + public TableRow next() { + try { + // Converts the String back into TableRow. + return JSON_FACTORY.fromString(iteratorInternal.next(), TableRow.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void remove() { + iteratorInternal.remove(); + } + } + + /** Returns the iterator. */ + @Override + public Iterator iterator() { + return new JsonTableRowIterator(jsonRows); + } + + /** Adds a BigQuery TableRow. */ + void add(TableRow row) { + try { + // Converts into JSON format. + jsonRows.add(JSON_FACTORY.toString(row)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + /** TableReference in JSON. Use String to make the class Serializable. */ + private final String jsonTableReference; + + /** TableSchema in JSON. Use String to make the class Serializable. */ + private final String jsonTableSchema; + + /** JsonTableRows to accumulate BigQuery rows. */ + private JsonTableRows jsonTableRows; + + /** The list of unique ids for each BigQuery table row. */ + private List uniqueIdsForTableRows; + + /** The list of tables created so far, so we don't try the creation + each time. */ + private static ThreadLocal> createdTables = + new ThreadLocal>() { + @Override protected HashSet initialValue() { + return new HashSet<>(); + } + }; + + /** Constructor. */ + StreamingWriteFn(TableReference table, TableSchema schema) { + try { + jsonTableReference = JSON_FACTORY.toString(table); + jsonTableSchema = JSON_FACTORY.toString(schema); + } catch (IOException e) { + throw new RuntimeException("Cannot initialize BigQuery streaming writer.", e); + } + } + + /** Prepares a target BigQuery table. */ + @Override + public void startBundle(Context context) { + jsonTableRows = new JsonTableRows(); + uniqueIdsForTableRows = new ArrayList<>(); + BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class); + Bigquery client = Transport.newBigQueryClient(options).build(); + + // TODO: Support table sharding and the better place to initialize + // BigQuery table. + HashSet tables = createdTables.get(); + if (!tables.contains(jsonTableSchema)) { + try { + TableSchema tableSchema = JSON_FACTORY.fromString( + jsonTableSchema, TableSchema.class); + TableReference tableReference = JSON_FACTORY.fromString( + jsonTableReference, TableReference.class); + + + BigQueryTableInserter inserter = new BigQueryTableInserter(client, tableReference); + inserter.tryCreateTable(tableSchema); + tables.add(jsonTableSchema); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + /** Accumulates the input into JsonTableRows and uniqueIdsForTableRows. */ + @Override + public void processElement(ProcessContext context) { + KV> kv = context.element(); + TableRow tableRow = kv.getValue().getValue(); + uniqueIdsForTableRows.add(kv.getValue().getKey()); + jsonTableRows.add(tableRow); + } + + /** Writes the accumulated rows into BigQuery with streaming API. */ + @Override + public void finishBundle(Context context) { + BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class); + Bigquery client = Transport.newBigQueryClient(options).build(); + + try { + TableReference tableReference = JSON_FACTORY.fromString( + jsonTableReference, TableReference.class); + + BigQueryTableInserter inserter = new BigQueryTableInserter(client, tableReference); + inserter.insertAll(jsonTableRows.iterator(), uniqueIdsForTableRows.iterator()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Fn that tags each table row with a unique id. + * To avoid calling UUID.randomUUID() for each element, which can be costly, + * a randomUUID is generated only once per bucket of data. The actual unique + * id is created by concatenating this randomUUID with a sequential number. + */ + private static class TagWithUniqueIds extends DoFn>> { + private transient String randomUUID; + private transient AtomicLong sequenceNo; + + @Override + public void startBundle(Context context) { + randomUUID = UUID.randomUUID().toString(); + sequenceNo = new AtomicLong(); + } + + /** Tag the input with a unique id. */ + @Override + public void processElement(ProcessContext context) { + String uniqueId = randomUUID + Long.toString(sequenceNo.getAndIncrement()); + ThreadLocalRandom randomGenerator = ThreadLocalRandom.current(); + // We output on keys 0-50 to ensure that there's enough batching for + // BigQuery. + context.output(KV.of(randomGenerator.nextInt(0, 50), + KV.of(uniqueId, context.element()))); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * PTransform that performs streaming BigQuery write. To increase consistency, + * it leverages BigQuery best effort de-dup mechanism. + */ + private static class StreamWithDeDup + extends PTransform, PDone> { + + private final TableReference tableReference; + private final TableSchema tableSchema; + + /** Constructor. */ + StreamWithDeDup(TableReference tableReference, TableSchema tableSchema) { + this.tableReference = tableReference; + this.tableSchema = tableSchema; + } + + @Override protected Coder getDefaultOutputCoder() { return VoidCoder.of(); } + + @Override + public PDone apply(PCollection in) { + // A naive implementation would be to simply stream data directly to BigQuery. + // However, this could occassionally lead to duplicated data, e.g., when + // a VM that runs this code is restarted and the code is re-run. + + // The above risk is mitigated in this implementation by relying on + // BigQuery built-in best effort de-dup mechanism. + + // To use this mechanism, each input TableRow is tagged with a generated + // unique id, which is then passed to BigQuery and used to ignore duplicates. + + PCollection>> tagged = + in.apply(ParDo.of(new TagWithUniqueIds())); + + // To prevent having the same TableRow processed more than once with regenerated + // different unique ids, this implementation relies on "checkpointing" which is + // achieved as a side effect of having StreamingWriteFn implement RequiresKeyedState. + tagged.apply(ParDo.of(new StreamingWriteFn(tableReference, tableSchema))); + + // Note that the implementation to return PDone here breaks the + // implicit assumption about the job execution order. If a user + // implements a PTransform that takes PDone returned here as its + // input, the transform may not necessarily be executed after + // the BigQueryIO.Write. + + return new PDone(); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Direct mode read evaluator. + *

+ * This loads the entire table into an in-memory PCollection. + */ + private static void evaluateReadHelper( + Read.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + BigQueryOptions options = context.getPipelineOptions(); + Bigquery client = Transport.newBigQueryClient(options).build(); + TableReference ref = transform.table; + if (ref.getProjectId() == null) { + ref.setProjectId(options.getProject()); + } + + LOG.info("Reading from BigQuery table {}", toTableSpec(ref)); + List elems = CloudSourceUtils.readElemsFromSource(new BigQuerySource(client, ref)); + LOG.info("Number of records read from BigQuery: {}", elems.size()); + context.setPCollection(transform.getOutput(), elems); + } + + /** + * Direct mode write evaluator. + *

+ * This writes the entire table in a single BigQuery request. + * The table will be created if necessary. + */ + private static void evaluateWriteHelper( + Write.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + BigQueryOptions options = context.getPipelineOptions(); + Bigquery client = Transport.newBigQueryClient(options).build(); + TableReference ref = transform.table; + if (ref.getProjectId() == null) { + ref.setProjectId(options.getProject()); + } + + LOG.info("Writing to BigQuery table {}", toTableSpec(ref)); + + try { + BigQueryTableInserter inserter = new BigQueryTableInserter(client, ref); + + inserter.getOrCreateTable(transform.writeDisposition, + transform.createDisposition, transform.schema); + + List tableRows = context.getPCollection(transform.getInput()); + inserter.insertAll(tableRows.iterator()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIO.java new file mode 100644 index 000000000000..9c7fc0a1c5b3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIO.java @@ -0,0 +1,603 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.services.datastore.DatastoreV1.BeginTransactionRequest; +import com.google.api.services.datastore.DatastoreV1.BeginTransactionResponse; +import com.google.api.services.datastore.DatastoreV1.CommitRequest; +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.api.services.datastore.client.Datastore; +import com.google.api.services.datastore.client.DatastoreException; +import com.google.api.services.datastore.client.DatastoreFactory; +import com.google.api.services.datastore.client.DatastoreHelper; +import com.google.api.services.datastore.client.DatastoreOptions; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.coders.EntityCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.Credentials; +import com.google.cloud.dataflow.sdk.util.RetryHttpRequestInitializer; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PDone; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; + +/** + * Transforms for reading and writing + * Google Cloud Datastore + * entities. + * + *

The DatastoreIO class provides an experimental API to Read and Write a + * {@link PCollection} of Datastore Entity. Currently the class supports + * read operations on both the DirectPipelineRunner and DataflowPipelineRunner, + * and write operations on the DirectPipelineRunner. This API is subject to + * change, and currently requires an authentication workaround described below. + * + *

Datastore is a fully managed NoSQL data storage service. + * An Entity is an object in Datastore, analogous to the a row in traditional + * database table. DatastoreIO supports Read/Write from/to Datastore within + * Dataflow SDK service. + * + *

To use DatastoreIO, users must set up the environment and use gcloud + * to get credential for Datastore: + *

+ * $ export CLOUDSDK_EXTRA_SCOPES=https://www.googleapis.com/auth/datastore
+ * $ gcloud auth login
+ * 
+ * + *

Note that the environment variable CLOUDSDK_EXTRA_SCOPES must be set + * to the same value when executing a Datastore pipeline, as the local auth + * cache is keyed by the requested scopes. + * + *

To read a {@link PCollection} from a query to Datastore, use + * {@link DatastoreIO.Read}, specifying {@link DatastoreIO.Read#from} to specify + * dataset to read, the query to read from, and optionally + * {@link DatastoreIO.Read#named} and {@link DatastoreIO.Read#withHost} to specify + * the name of the pipeline step and the host of Datastore, respectively. + * For example: + * + *

 {@code
+ * // Read a query from Datastore
+ * PipelineOptions options =
+ *     CliPipelineOptionsFactory.create(PipelineOptions.class, args);
+ * Pipeline p = Pipeline.create(options);
+ * PCollection entities =
+ *     p.apply(DatastoreIO.Read
+ *             .named("Read Datastore")
+ *             .from(datasetId, query)
+ *             .withHost(host));
+ * p.run();
+ * } 
+ * + *

To write a {@link PCollection} to a datastore, use + * {@link DatastoreIO.Write}, specifying {@link DatastoreIO.Write#to} to specify + * the datastore to write to, and optionally {@link TextIO.Write#named} to specify + * the name of the pipeline step. For example: + * + *

 {@code
+ * // A simple Write to Datastore with DirectPipelineRunner (writing is not
+ * // yet implemented for other runners):
+ * PCollection entities = ...;
+ * lines.apply(DatastoreIO.Write.to("Write entities", datastore));
+ * p.run();
+ *
+ * } 
+ */ + +public class DatastoreIO { + + private static final Logger LOG = LoggerFactory.getLogger(DatastoreIO.class); + private static final String DEFAULT_HOST = "https://www.googleapis.com"; + + /** + * A PTransform that reads from a Datastore query and returns a + * {@code PCollection} containing each of the rows of the table. + */ + public static class Read { + + /** + * Returns a DatastoreIO.Read PTransform with the given step name. + */ + public static Bound named(String name) { + return new Bound(DEFAULT_HOST).named(name); + } + + /** + * Reads entities retrieved from the dataset and a given query. + */ + public static Bound from(String datasetId, Query query) { + return new Bound(DEFAULT_HOST).from(datasetId, query); + } + + /** + * Returns a DatastoreIO.Read PTransform with specified host. + */ + public static Bound withHost(String host) { + return new Bound(host); + } + + /** + * A PTransform that reads from a Datastore query and returns a bounded + * {@code PCollection}. + */ + public static class Bound extends PTransform> { + String host; + String datasetId; + Query query; + + /** + * Returns a DatastoreIO.Bound object with given query. + * Sets the name, Datastore host, datasetId, query associated + * with this PTransform, and options for this Pipeline. + */ + Bound(String name, String host, String datasetId, Query query) { + super(name); + this.host = host; + this.datasetId = datasetId; + this.query = query; + } + + /** + * Returns a DatastoreIO.Read PTransform with host set up. + */ + Bound(String host) { + this.host = host; + } + + /** + * Returns a new DatastoreIO.Read PTransform with the name + * associated with this transformation. + */ + public Bound named(String name) { + return new Bound(name, host, datasetId, query); + } + + /** + * Returns a new DatastoreIO.Read PTransform with datasetId, + * and query associated with this transformation, and options + * associated with this Pipleine. + */ + public Bound from(String datasetId, Query query) { + return new Bound(name, host, datasetId, query); + } + + /** + * Returns a new DatastoreIO.Read PTransform with the host + * specified. + */ + public Bound withHost(String host) { + return new Bound(name, host, datasetId, query); + } + + @Override + public PCollection apply(PBegin input) { + if (datasetId == null || query == null) { + throw new IllegalStateException( + "need to set datasetId, and query " + + "of a DatastoreIO.Read transform"); + } + + QueryOptions queryOptions = QueryOptions.create(host, datasetId, query); + PCollection output; + try { + DataflowPipelineOptions options = + getPipeline().getOptions().as(DataflowPipelineOptions.class); + PCollection queries = splitQueryOptions(queryOptions, options, input); + + output = queries.apply(ParDo.of(new ReadEntitiesFn())); + getCoderRegistry().registerCoder(Entity.class, EntityCoder.class); + } catch (DatastoreException e) { + LOG.warn("DatastoreException: error while doing Datastore query splitting.", e); + throw new RuntimeException("Error while splitting Datastore query."); + } + + return output; + } + } + } + + ///////////////////// Write Class ///////////////////////////////// + /** + * A PTransform that writes a {@code PCollection} containing + * entities to a Datastore kind. + * + * Current version only supports Write operation running on + * DirectPipelineRunner. If Write is used on DataflowPipelineRunner, + * it throws UnsupportedOperationException and won't continue on the + * operation. + * + */ + public static class Write { + /** + * Returns a DatastoreIO.Write PTransform with the name + * associated with this PTransform. + */ + public static Bound named(String name) { + return new Bound(DEFAULT_HOST).named(name); + } + + /** + * Returns a DatastoreIO.Write PTransform with given datasetId. + */ + public static Bound to(String datasetId) { + return new Bound(DEFAULT_HOST).to(datasetId); + } + + /** + * Returns a DatastoreIO.Write PTransform with specified host. + */ + public static Bound withHost(String host) { + return new Bound(host); + } + + /** + * A PTransform that writes a bounded {@code PCollection} + * to a Datastore. + */ + public static class Bound extends PTransform, PDone> { + String host; + String datasetId; + + /** + * Returns a DatastoreIO.Write PTransform with given host. + */ + Bound(String host) { + this.host = host; + } + + /** + * Returns a DatastoreIO.Write.Bound object. + * Sets the name, datastore agent, and kind associated + * with this transformation. + */ + Bound(String name, String host, String datasetId) { + super(name); + this.host = host; + this.datasetId = datasetId; + } + + /** + * Returns a DatastoreIO.Write PTransform with the name + * associated with this PTransform. + */ + public Bound named(String name) { + return new Bound(name, host, datasetId); + } + + /** + * Returns a DatastoreIO.Write PTransform with given datasetId. + */ + public Bound to(String datasetId) { + return new Bound(name, host, datasetId); + } + + /** + * Returns a new DatastoreIO.Write PTransform with specified host. + */ + public Bound withHost(String host) { + return new Bound(name, host, datasetId); + } + + @Override + public PDone apply(PCollection input) { + if (this.host == null || this.datasetId == null) { + throw new IllegalStateException( + "need to set Datastore host and dataasetId" + + "of a DatastoreIO.Write transform"); + } + + return new PDone(); + } + + @Override + protected String getKindString() { return "DatastoreIO.Write"; } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateWriteHelper(transform, context); + } + }); + } + } + } + + /////////////////////////////////////////////////////////////////// + + /** + * A DoFn that performs query request to Datastore and converts + * each QueryOptions into Entities. + */ + private static class ReadEntitiesFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Query query = c.element().getQuery(); + Datastore datastore = c.element().getWorkerDatastore( + c.getPipelineOptions().as(GcpOptions.class)); + DatastoreIterator entityIterator = new DatastoreIterator(query, datastore); + + while (entityIterator.hasNext()) { + c.output(entityIterator.next().getEntity()); + } + } + } + + /** + * A class that stores query and datastore setup environments + * (host and datasetId). + */ + @DefaultCoder(AvroCoder.class) + private static class QueryOptions { + // Query to read in byte array. + public byte[] byteQuery; + + // Datastore host to read from. + public String host; + + // Datastore dataset ID to read from. + public String datasetId; + + @SuppressWarnings("unused") + QueryOptions() {} + + /** + * Returns a QueryOption object without account and private key file + * (for supporting query on local Datastore). + * + * @param host the host of Datastore to connect + * @param datasetId the dataset ID of Datastore to query + * @param query the query to perform + */ + QueryOptions(String host, String datasetId, Query query) { + this.host = host; + this.datasetId = datasetId; + this.setQuery(query); + } + + /** + * Creates and returns a QueryOption object for query on local Datastore. + * + * @param host the host of Datastore to connect + * @param datasetId the dataset ID of Datastore to query + * @param query the query to perform + */ + public static QueryOptions create(String host, String datasetId, Query query) { + return new QueryOptions(host, datasetId, query); + } + + /** + * Sets up a query. + * Stores query in a byte array so that we can use AvroCoder to encode/decode + * QueryOptions. + * + * @param q the query to be addressed + */ + public void setQuery(Query q) { + this.byteQuery = q.toByteArray(); + } + + /** + * Returns query. + * + * @return query in this option. + */ + public Query getQuery() { + try { + return Query.parseFrom(this.byteQuery); + } catch (IOException e) { + LOG.warn("IOException: parsing query failed.", e); + throw new RuntimeException("Cannot parse query from byte array."); + } + } + + /** + * Returns the dataset ID. + * + * @return a dataset ID string for Datastore. + */ + public String getDatasetId() { + return this.datasetId; + } + + /** + * Returns a copy of QueryOptions from current options with given query. + * + * @param query a new query to be set + * @return A QueryOptions object for query + */ + public QueryOptions newQuery(Query query) { + return create(host, datasetId, query); + } + + /** + * Returns a Datastore object for connecting to Datastore on workers. + * This method will try to get worker credential from Credentials + * library and constructs a Datastore object which is set up and + * ready to communicate with Datastore. + * + * @return a Datastore object setup with host and dataset. + */ + public Datastore getWorkerDatastore(GcpOptions options) { + DatastoreOptions.Builder builder = new DatastoreOptions.Builder() + .host(this.host) + .dataset(this.datasetId) + .initializer(new RetryHttpRequestInitializer(null)); + + try { + Credential credential = Credentials.getWorkerCredential(options); + builder.credential(credential); + } catch (IOException e) { + LOG.warn("IOException: can't get credential for worker.", e); + throw new RuntimeException("Failed on getting credential for worker."); + } + return DatastoreFactory.get().create(builder.build()); + } + + /** + * Returns a Datastore object for connecting to Datastore for users. + * This method will use the passed in credentials and construct a Datastore + * object which is set up and ready to communicate with Datastore. + * + * @return a Datastore object setup with host and dataset. + */ + public Datastore getUserDatastore(GcpOptions options) { + DatastoreOptions.Builder builder = new DatastoreOptions.Builder() + .host(this.host) + .dataset(this.datasetId) + .initializer(new RetryHttpRequestInitializer(null)); + + Credential credential = options.getGcpCredential(); + if (credential != null) { + builder.credential(credential); + } + return DatastoreFactory.get().create(builder.build()); + } + } + + /** + * Returns a list of QueryOptions by splitting a QueryOptions into sub-queries. + * This method leverages the QuerySplitter in Datastore to split the + * query into sub-queries for further parallel query in Dataflow service. + * + * @return a PCollection of QueryOptions for split queries + */ + private static PCollection splitQueryOptions( + QueryOptions queryOptions, DataflowPipelineOptions options, + PBegin input) + throws DatastoreException { + Query query = queryOptions.getQuery(); + Datastore datastore = queryOptions.getUserDatastore(options); + + // Get splits from the QuerySplit interface. + List splitQueries = DatastoreHelper.getQuerySplitter() + .getSplits(query, options.getNumWorkers(), datastore); + + List> queryList = new LinkedList<>(); + for (Query q : splitQueries) { + PCollection newQuery = input + .apply(Create.of(queryOptions.newQuery(q))); + queryList.add(newQuery); + } + + // This is a workaround to allow for parallelism of a small collection. + return PCollectionList.of(queryList) + .apply(Flatten.create()); + } + + ///////////////////////////////////////////////////////////////////// + + /** + * Direct mode write evaluator. + * This writes the result to Datastore. + */ + private static void evaluateWriteHelper( + Write.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + LOG.info("Writing to Datastore"); + GcpOptions options = context.getPipelineOptions(); + Credential credential = options.getGcpCredential(); + Datastore datastore = DatastoreFactory.get().create( + new DatastoreOptions.Builder() + .host(transform.host) + .dataset(transform.datasetId) + .credential(credential) + .initializer(new RetryHttpRequestInitializer(null)) + .build()); + + List entityList = context.getPCollection(transform.getInput()); + + // Create a map to put entities with same ancestor for writing in a batch. + HashMap> map = new HashMap<>(); + for (Entity e : entityList) { + String keyOfAncestor = e.getKey().getPathElement(0).getKind() + + e.getKey().getPathElement(0).getName(); + List value = map.get(keyOfAncestor); + if (value == null) { + value = new ArrayList<>(); + } + value.add(e); + map.put(keyOfAncestor, value); + } + + // Walk over the map, and write entities bucket by bucket. + int count = 0; + for (String k : map.keySet()) { + List entitiesWithSameAncestor = map.get(k); + List toInsert = new ArrayList<>(); + for (Entity e : entitiesWithSameAncestor) { + toInsert.add(e); + // Note that Datastore has limit as 500 for a batch operation, + // so just flush to Datastore with every 500 entties. + if (toInsert.size() >= 500) { + writeBatch(toInsert, datastore); + toInsert.clear(); + } + } + writeBatch(toInsert, datastore); + count += entitiesWithSameAncestor.size(); + } + + LOG.info("Total number of entities written: {}", count); + } + + /** + * A function for batch writing to Datastore. + */ + private static void writeBatch(List listOfEntities, Datastore datastore) { + try { + BeginTransactionRequest.Builder treq = BeginTransactionRequest.newBuilder(); + BeginTransactionResponse tres = datastore.beginTransaction(treq.build()); + CommitRequest.Builder creq = CommitRequest.newBuilder(); + creq.setTransaction(tres.getTransaction()); + creq.getMutationBuilder().addAllInsertAutoId(listOfEntities); + datastore.commit(creq.build()); + } catch (DatastoreException e) { + LOG.warn("Error while doing datastore operation: {}", e); + throw new RuntimeException("Datastore exception", e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIterator.java new file mode 100644 index 000000000000..1b6d92e73c76 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIterator.java @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.api.services.datastore.DatastoreV1.EntityResult; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.api.services.datastore.DatastoreV1.QueryResultBatch; +import com.google.api.services.datastore.DatastoreV1.RunQueryRequest; +import com.google.api.services.datastore.DatastoreV1.RunQueryResponse; +import com.google.api.services.datastore.client.Datastore; +import com.google.api.services.datastore.client.DatastoreException; +import com.google.common.collect.AbstractIterator; + +import java.util.Iterator; + +/** + * An iterator over the records from a query of the datastore. + * + *

Usage: + *

{@code
+ *   // Need to pass query and datastore object.
+ *   DatastoreIterator iterator = new DatastoreIterator(query, datastore);
+ *   while (iterator.hasNext()) {
+ *     Entity e = iterator.next().getEntity();
+ *     ...
+ *   }
+ * }
+ */ +class DatastoreIterator extends AbstractIterator { + /** + * Query to select records. + */ + private Query.Builder query; + + /** + * Datastore to read from. + */ + private Datastore datastore; + + /** + * True if more results may be available. + */ + private boolean moreResults; + + /** + * Iterator over records. + */ + private Iterator entities; + + /** + * Current batch of query results. + */ + private QueryResultBatch currentBatch; + + /** + * Maximum number of results to request per query. + * + *

Must be set, or it may result in an I/O error when querying + * Cloud Datastore. + */ + private static final int QUERY_LIMIT = 5000; + + /** + * Returns a DatastoreIterator with query and Datastore object set. + * + * @param query the query to select records. + * @param datastore a datastore connection to use. + */ + public DatastoreIterator(Query query, Datastore datastore) { + this.query = query.toBuilder().clone(); + this.datastore = datastore; + this.query.setLimit(QUERY_LIMIT); + } + + /** + * Returns an iterator over the next batch of records for the query + * and updates the cursor to get the next batch as needed. + * Query has specified limit and offset from InputSplit. + */ + private Iterator getIteratorAndMoveCursor() + throws DatastoreException{ + if (this.currentBatch != null && this.currentBatch.hasEndCursor()) { + this.query.setStartCursor(this.currentBatch.getEndCursor()); + } + + RunQueryRequest request = RunQueryRequest.newBuilder() + .setQuery(this.query) + .build(); + RunQueryResponse response = this.datastore.runQuery(request); + + this.currentBatch = response.getBatch(); + + // MORE_RESULTS_AFTER_LIMIT is not implemented yet: + // https://groups.google.com/forum/#!topic/gcd-discuss/iNs6M1jA2Vw, so + // use result count to determine if more results might exist. + int numFetch = this.currentBatch.getEntityResultCount(); + moreResults = numFetch == QUERY_LIMIT; + + // May receive a batch of 0 results if the number of records is a multiple + // of the request limit. + if (numFetch == 0) { + return null; + } + + return this.currentBatch.getEntityResultList().iterator(); + } + + @Override + public EntityResult computeNext() { + try { + if (entities == null || (!entities.hasNext() && this.moreResults)) { + entities = getIteratorAndMoveCursor(); + } + + if (entities == null || !entities.hasNext()) { + return endOfData(); + } + + return entities.next(); + + } catch (DatastoreException e) { + throw new RuntimeException( + "Datastore error while iterating over entities", e); + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubIO.java new file mode 100644 index 000000000000..b9f051484159 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubIO.java @@ -0,0 +1,331 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * [Whitelisting Required] Read and Write transforms for Pub/Sub streams. These transforms create + * and consume unbounded {@link com.google.cloud.dataflow.sdk.values.PCollection}s. + * + *

Important: PubsubIO is experimental. It is not supported by the + * {@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner} and is only supported in the + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner} for users whitelisted in a + * streaming early access program and who enable + * {@link com.google.cloud.dataflow.sdk.options.StreamingOptions#setStreaming(boolean)}. + * + *

You should expect this class to change significantly in future versions of the SDK + * or be removed entirely. + */ +public class PubsubIO { + + /** + * Project IDs must contain 6-63 lowercase letters, digits, or dashes. + * IDs must start with a letter and may not end with a dash. + * This regex isn't exact - this allows for patterns that would be rejected by + * the service, but this is sufficient for basic parsing of table references. + */ + private static final Pattern PROJECT_ID_REGEXP = + Pattern.compile("[a-z][-a-z0-9:.]{4,61}[a-z0-9]"); + + private static final Pattern SUBSCRIPTION_REGEXP = + Pattern.compile("/subscriptions/([^/]+)/(.+)"); + + private static final Pattern TOPIC_REGEXP = + Pattern.compile("/topics/([^/]+)/(.+)"); + + private static final Pattern PUBSUB_NAME_REGEXP = + Pattern.compile("[a-z][-._a-z0-9]+[a-z0-9]"); + + private static final int PUBSUB_NAME_MAX_LENGTH = 255; + + private static final String SUBSCRIPTION_RANDOM_TEST_PREFIX = "_random/"; + private static final String TOPIC_DEV_NULL_TEST_NAME = "/topics/dev/null"; + + /** + * Utility class to validate topic and subscription names. + */ + public static class Validator { + public static void validateTopicName(String topic) { + if (topic.equals(TOPIC_DEV_NULL_TEST_NAME)) { + return; + } + Matcher match = TOPIC_REGEXP.matcher(topic); + if (!match.matches()) { + throw new IllegalArgumentException( + "Pubsub topic is not in /topics/project_id/topic_name format: " + + topic); + } + validateProjectName(match.group(1)); + validatePubsubName(match.group(2)); + } + + public static void validateSubscriptionName(String subscription) { + if (subscription.startsWith(SUBSCRIPTION_RANDOM_TEST_PREFIX)) { + return; + } + Matcher match = SUBSCRIPTION_REGEXP.matcher(subscription); + if (!match.matches()) { + throw new IllegalArgumentException( + "Pubsub subscription is not in /subscriptions/project_id/subscription_name format: " + + subscription); + } + validateProjectName(match.group(1)); + validatePubsubName(match.group(2)); + } + + private static void validateProjectName(String project) { + Matcher match = PROJECT_ID_REGEXP.matcher(project); + if (!match.matches()) { + throw new IllegalArgumentException( + "Illegal project name specified in Pubsub subscription: " + project); + } + } + + private static void validatePubsubName(String name) { + if (name.length() > PUBSUB_NAME_MAX_LENGTH) { + throw new IllegalArgumentException( + "Pubsub object name is longer than 255 characters: " + name); + } + + if (name.startsWith("goog")) { + throw new IllegalArgumentException( + "Pubsub object name cannot start with goog: " + name); + } + + Matcher match = PUBSUB_NAME_REGEXP.matcher(name); + if (!match.matches()) { + throw new IllegalArgumentException( + "Illegal Pubsub object name specified: " + name + + " Please see Javadoc for naming rules."); + } + } + } + + /** + * A PTransform that continuously reads from a Pubsub stream and + * returns a {@code PCollection} containing the items from + * the stream. + */ + // TODO: Support non-String encodings. + public static class Read { + public static Bound named(String name) { + return new Bound().named(name); + } + + /** + * Creates and returns a PubsubIO.Read PTransform for reading from + * a Pubsub topic with the specified publisher topic. Format for + * Cloud Pubsub topic names should be of the form /topics//, + * where is the name of the publishing project. + * The component must comply with the below requirements. + *

    + *
  • Can only contain lowercase letters, numbers, dashes ('-'), underscores ('_') and periods + * ('.').
  • + *
  • Must be between 3 and 255 characters.
  • + *
  • Must begin with a letter.
  • + *
  • Must end with a letter or a number.
  • + *
  • Cannot begin with 'goog' prefix.
  • + *
+ */ + public static Bound topic(String topic) { + return new Bound().topic(topic); + } + + /** + * Creates and returns a PubsubIO.Read PTransform for reading from + * a specific Pubsub subscription. Mutually exclusive with + * PubsubIO.Read.topic(). + * Cloud Pubsub subscription names should be of the form + * /subscriptions//<, + * where is the name of the project the subscription belongs to. + * The component must comply with the below requirements. + *
    + *
  • Can only contain lowercase letters, numbers, dashes ('-'), underscores ('_') and periods + * ('.').
  • + *
  • Must be between 3 and 255 characters.
  • + *
  • Must begin with a letter.
  • + *
  • Must end with a letter or a number.
  • + *
  • Cannot begin with 'goog' prefix.
  • + *
+ */ + public static Bound subscription(String subscription) { + return new Bound().subscription(subscription); + } + + /** + * A PTransform that reads from a PubSub source and returns + * a unbounded PCollection containing the items from the stream. + */ + public static class Bound + extends PTransform> { + /** The Pubsub topic to read from. */ + String topic; + /** The Pubsub subscription to read from */ + String subscription; + + Bound() {} + + Bound(String name, String subscription, String topic) { + super(name); + if (subscription != null) { + Validator.validateSubscriptionName(subscription); + } + if (topic != null) { + Validator.validateTopicName(topic); + } + this.subscription = subscription; + this.topic = topic; + } + + public Bound named(String name) { + return new Bound(name, subscription, topic); + } + + public Bound subscription(String subscription) { + return new Bound(name, subscription, topic); + } + + public Bound topic(String topic) { + return new Bound(name, subscription, topic); + } + + @Override + public PCollection apply(PInput input) { + if (topic == null && subscription == null) { + throw new IllegalStateException( + "need to set either the topic or the subscription for " + + "a PubsubIO.Read transform"); + } + if (topic != null && subscription != null) { + throw new IllegalStateException( + "Can't set both the topic and the subscription for a " + + "PubsubIO.Read transform"); + } + return PCollection.createPrimitiveOutputInternal( + new GlobalWindow()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return StringUtf8Coder.of(); + } + + @Override + protected String getKindString() { return "PubsubIO.Read"; } + + public String getTopic() { + return topic; + } + + public String getSubscription() { + return subscription; + } + + static { + // TODO: Figure out how to make this work under + // DirectPipelineRunner. + } + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A PTransform that continuously writes a + * {@code PCollection} to a Pubsub stream. + */ + // TODO: Support non-String encodings. + public static class Write { + public static Bound named(String name) { + return new Bound().named(name); + } + + /** The topic to publish to. + * Cloud Pubsub topic names should be /topics//, + * where is the name of the publishing project. + */ + public static Bound topic(String topic) { + return new Bound().topic(topic); + } + + /** + * A PTransfrom that writes a unbounded {@code PCollection} + * to a PubSub stream. + */ + public static class Bound + extends PTransform, PDone> { + /** The Pubsub topic to publish to. */ + String topic; + + Bound() {} + + Bound(String name, String topic) { + super(name); + if (topic != null) { + Validator.validateTopicName(topic); + this.topic = topic; + } + } + + public Bound named(String name) { + return new Bound(name, topic); + } + + public Bound topic(String topic) { + return new Bound(name, topic); + } + + @Override + public PDone apply(PCollection input) { + if (topic == null) { + throw new IllegalStateException( + "need to set the topic of a PubsubIO.Write transform"); + } + return new PDone(); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + @Override + protected String getKindString() { return "PubsubIO.Write"; } + + public String getTopic() { + return topic; + } + + static { + // TODO: Figure out how to make this work under + // DirectPipelineRunner. + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/ShardNameTemplate.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/ShardNameTemplate.java new file mode 100644 index 000000000000..5ab0a99084b8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/ShardNameTemplate.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +/** + * Standard shard naming templates. + * + *

Shard naming templates are strings which may contain placeholders for + * the shard number and shard count. When constructing a filename for a + * particular shard number, the upper-case letters 'S' and 'N' are replaced + * with the 0-padded shard number and shard count respectively. + * + *

Left-padding of the numbers enables lexicographical sorting of the + * resulting filenames. If the shard number or count are too large for the + * space provided in the template, then the result may no longer sort + * lexicographically. For example, a shard template of "S-of-N", for 200 + * shards, will result in outputs named "0-of-200", ... '10-of-200', + * '100-of-200", etc. + * + *

Shard numbers start with 0, so the last shard number is the shard count + * minus one. For example, the template "-SSSSS-of-NNNNN" will be + * instantiated as "-00000-of-01000" for the first shard (shard 0) of a + * 1000-way sharded output. + * + *

A shard name template is typically provided along with a name prefix + * and suffix, which allows constructing complex paths which have embedded + * shard information. For example, outputs in the form + * "gs://bucket/path-01-of-99.txt" could be constructed by providing the + * individual components: + * + *

{@code
+ *   pipeline.apply(
+ *       TextIO.Write.to("gs://bucket/path")
+ *                   .withShardNameTemplate("-SS-of-NN")
+ *                   .withSuffix(".txt"))
+ * }
+ * + *

In the example above, you could make parts of the output configurable + * by users without the user having to specify all components of the output + * name. + * + *

If a shard name template does not contain any repeating 'S', then + * the output shard count must be 1, as otherwise the same filename would be + * generated for multiple shards. + */ +public class ShardNameTemplate { + /** + * Shard name containing the index and max. + * + *

Eg: [prefix]-00000-of-00100[suffix] and + * [prefix]-00001-of-00100[suffix] + */ + public static final String INDEX_OF_MAX = "-SSSSS-of-NNNNN"; + + /** + * Shard is a file within a directory. + * + *

Eg: [prefix]/part-00000[suffix] and [prefix]/part-00001[suffix] + */ + public static final String DIRECTORY_CONTAINER = "/part-SSSSS"; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java new file mode 100644 index 000000000000..5d1cb205b422 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java @@ -0,0 +1,567 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.util.CloudSourceUtils.readElemsFromSource; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.worker.TextSink; +import com.google.cloud.dataflow.sdk.runners.worker.TextSource; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; + +import java.io.IOException; +import java.util.List; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * Transforms for reading and writing text files. + * + *

To read a {@link PCollection} from one or more text files, use + * {@link TextIO.Read}, specifying {@link TextIO.Read#from} to specify + * the path of the file(s) to read from (e.g., a local filename or + * filename pattern if running locally, or a Google Cloud Storage + * filename or filename pattern of the form + * {@code "gs:///"}), and optionally + * {@link TextIO.Read#named} to specify the name of the pipeline step + * and/or {@link TextIO.Read#withCoder} to specify the Coder to use to + * decode the text lines into Java values. For example: + * + *

 {@code
+ * Pipeline p = ...;
+ *
+ * // A simple Read of a local file (only runs locally):
+ * PCollection lines =
+ *     p.apply(TextIO.Read.from("/path/to/file.txt"));
+ *
+ * // A fully-specified Read from a GCS file (runs locally and via the
+ * // Google Cloud Dataflow service):
+ * PCollection numbers =
+ *     p.apply(TextIO.Read.named("ReadNumbers")
+ *                        .from("gs://my_bucket/path/to/numbers-*.txt")
+ *                        .withCoder(TextualIntegerCoder.of()));
+ * } 
+ * + *

To write a {@link PCollection} to one or more text files, use + * {@link TextIO.Write}, specifying {@link TextIO.Write#to} to specify + * the path of the file to write to (e.g., a local filename or sharded + * filename pattern if running locally, or a Google Cloud Storage + * filename or sharded filename pattern of the form + * {@code "gs:///"}), and optionally + * {@link TextIO.Write#named} to specify the name of the pipeline step + * and/or {@link TextIO.Write#withCoder} to specify the Coder to use + * to encode the Java values into text lines. For example: + * + *

 {@code
+ * // A simple Write to a local file (only runs locally):
+ * PCollection lines = ...;
+ * lines.apply(TextIO.Write.to("/path/to/file.txt"));
+ *
+ * // A fully-specified Write to a sharded GCS file (runs locally and via the
+ * // Google Cloud Dataflow service):
+ * PCollection numbers = ...;
+ * numbers.apply(TextIO.Write.named("WriteNumbers")
+ *                           .to("gs://my_bucket/path/to/numbers")
+ *                           .withSuffix(".txt")
+ *                           .withCoder(TextualIntegerCoder.of()));
+ * } 
+ */ +public class TextIO { + public static final Coder DEFAULT_TEXT_CODER = StringUtf8Coder.of(); + + /** + * A root PTransform that reads from a text file (or multiple text + * files matching a pattern) and returns a PCollection containing + * the decoding of each of the lines of the text file(s). The + * default decoding just returns the lines. + */ + public static class Read { + /** + * Returns a TextIO.Read PTransform with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(DEFAULT_TEXT_CODER).named(name); + } + + /** + * Returns a TextIO.Read PTransform that reads from the file(s) + * with the given name or pattern. This can be a local filename + * or filename pattern (if running locally), or a Google Cloud + * Storage filename or filename pattern of the form + * {@code "gs:///"}) (if running locally or via + * the Google Cloud Dataflow service). Standard + * Java Filesystem glob patterns ("*", "?", "[..]") are supported. + */ + public static Bound from(String filepattern) { + return new Bound<>(DEFAULT_TEXT_CODER).from(filepattern); + } + + /** + * Returns a TextIO.Read PTransform that uses the given + * {@code Coder} to decode each of the lines of the file into a + * value of type {@code T}. + * + *

By default, uses {@link StringUtf8Coder}, which just + * returns the text lines as Java strings. + * + * @param the type of the decoded elements, and the elements + * of the resulting PCollection + */ + public static Bound withCoder(Coder coder) { + return new Bound<>(coder); + } + + // TODO: strippingNewlines, gzipped, etc. + + /** + * A root PTransform that reads from a text file (or multiple text files + * matching a pattern) and returns a bounded PCollection containing the + * decoding of each of the lines of the text file(s). The default + * decoding just returns the lines. + * + * @param the type of each of the elements of the resulting + * PCollection, decoded from the lines of the text file + */ + public static class Bound + extends PTransform> { + /** The filepattern to read from. */ + @Nullable final String filepattern; + + /** The Coder to use to decode each line. */ + @Nullable final Coder coder; + + Bound(Coder coder) { + this(null, null, coder); + } + + Bound(String name, String filepattern, Coder coder) { + super(name); + this.coder = coder; + this.filepattern = filepattern; + } + + /** + * Returns a new TextIO.Read PTransform that's like this one but + * with the given step name. Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filepattern, coder); + } + + /** + * Returns a new TextIO.Read PTransform that's like this one but + * that reads from the file(s) with the given name or pattern. + * (See {@link TextIO.Read#from} for a description of + * filepatterns.) Does not modify this object. + */ + public Bound from(String filepattern) { + return new Bound<>(name, filepattern, coder); + } + + /** + * Returns a new TextIO.Read PTransform that's like this one but + * that uses the given {@code Coder} to decode each of the + * lines of the file into a value of type {@code T1}. Does not + * modify this object. + * + * @param the type of the decoded elements, and the + * elements of the resulting PCollection + */ + public Bound withCoder(Coder coder) { + return new Bound<>(name, filepattern, coder); + } + + @Override + public PCollection apply(PInput input) { + if (filepattern == null) { + throw new IllegalStateException( + "need to set the filepattern of a TextIO.Read transform"); + } + // Force the output's Coder to be what the read is using, and + // unchangeable later, to ensure that we read the input in the + // format specified by the Read transform. + return PCollection.createPrimitiveOutputInternal(new GlobalWindow()) + .setCoder(coder); + } + + @Override + protected Coder getDefaultOutputCoder() { + return coder; + } + + @Override + protected String getKindString() { return "TextIO.Read"; } + + public String getFilepattern() { + return filepattern; + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateReadHelper(transform, context); + } + }); + } + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A PTransform that writes a PCollection to a text file (or + * multiple text files matching a sharding pattern), with each + * PCollection element being encoded into its own line. + */ + public static class Write { + /** + * Returns a TextIO.Write PTransform with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(DEFAULT_TEXT_CODER).named(name); + } + + /** + * Returns a TextIO.Write PTransform that writes to the file(s) + * with the given prefix. This can be a local filename + * (if running locally), or a Google Cloud Storage filename of + * the form {@code "gs:///"}) + * (if running locally or via the Google Cloud Dataflow service). + * + *

The files written will begin with this prefix, followed by + * a shard identifier (see {@link Bound#withNumShards}, and end + * in a common extension, if given by {@link Bound#withSuffix}. + */ + public static Bound to(String prefix) { + return new Bound<>(DEFAULT_TEXT_CODER).to(prefix); + } + + /** + * Returns a TextIO.Write PTransform that writes to the file(s) with the + * given filename suffix. + */ + public static Bound withSuffix(String nameExtension) { + return new Bound<>(DEFAULT_TEXT_CODER).withSuffix(nameExtension); + } + + /** + * Returns a TextIO.Write PTransform that uses the provided shard count. + * + *

Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + */ + public static Bound withNumShards(int numShards) { + return new Bound<>(DEFAULT_TEXT_CODER).withNumShards(numShards); + } + + /** + * Returns a TextIO.Write PTransform that uses the given shard name + * template. + * + *

See {@link ShardNameTemplate} for a description of shard templates. + */ + public static Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(DEFAULT_TEXT_CODER) + .withShardNameTemplate(shardTemplate); + } + + /** + * Returns a TextIO.Write PTransform that forces a single file as + * output. + */ + public static Bound withoutSharding() { + return new Bound<>(DEFAULT_TEXT_CODER).withoutSharding(); + } + + /** + * Returns a TextIO.Write PTransform that uses the given + * {@code Coder} to encode each of the elements of the input + * {@code PCollection} into an output text line. + * + *

By default, uses {@link StringUtf8Coder}, which writes input + * Java strings directly as output lines. + * + * @param the type of the elements of the input PCollection + */ + public static Bound withCoder(Coder coder) { + return new Bound<>(coder); + } + + // TODO: appendingNewlines, gzipped, header, footer, etc. + + /** + * A PTransform that writes a bounded PCollection to a text file (or + * multiple text files matching a sharding pattern), with each + * PCollection element being encoded into its own line. + * + * @param the type of the elements of the input PCollection + */ + public static class Bound + extends PTransform, PDone> { + /** The filename to write to. */ + @Nullable final String filenamePrefix; + /** Suffix to use for each filename. */ + final String filenameSuffix; + + /** The Coder to use to decode each line. */ + final Coder coder; + + /** Requested number of shards. 0 for automatic. */ + final int numShards; + + /** Shard template string. */ + final String shardTemplate; + + Bound(Coder coder) { + this(null, null, "", coder, 0, ShardNameTemplate.INDEX_OF_MAX); + } + + Bound(String name, String filenamePrefix, String filenameSuffix, + Coder coder, int numShards, + String shardTemplate) { + super(name); + this.coder = coder; + this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + this.numShards = numShards; + this.shardTemplate = shardTemplate; + } + + /** + * Returns a new TextIO.Write PTransform that's like this one but + * with the given step name. Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate); + } + + /** + * Returns a new TextIO.Write PTransform that's like this one but + * that writes to the file(s) with the given filename prefix. + * + *

See {@link Write#to(String) Write.to(String)} for more information. + * + *

Does not modify this object. + */ + public Bound to(String filenamePrefix) { + validateOutputComponent(filenamePrefix); + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate); + } + + /** + * Returns a new TextIO.Write PTransform that's like this one but + * that writes to the file(s) with the given filename suffix. + * + *

Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withSuffix(String nameExtension) { + validateOutputComponent(nameExtension); + return new Bound<>(name, filenamePrefix, nameExtension, coder, numShards, + shardTemplate); + } + + /** + * Returns a new TextIO.Write PTransform that's like this one but + * that uses the provided shard count. + * + *

Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + *

Does not modify this object. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + * @see ShardNameTemplate + */ + public Bound withNumShards(int numShards) { + Preconditions.checkArgument(numShards >= 0); + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate); + } + + /** + * Returns a new TextIO.Write PTransform that's like this one but + * that uses the given shard name template. + * + *

Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate); + } + + /** + * Returns a new TextIO.Write PTransform that's like this one but + * that forces a single file as output. + * + *

This is a shortcut for + * {@code .withNumShards(1).withShardNameTemplate("")} + * + *

Does not modify this object. + */ + public Bound withoutSharding() { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, 1, ""); + } + + /** + * Returns a new TextIO.Write PTransform that's like this one + * but that uses the given {@code Coder} to encode each of + * the elements of the input {@code PCollection} into an + * output text line. Does not modify this object. + * + * @param the type of the elements of the input PCollection + */ + public Bound withCoder(Coder coder) { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate); + } + + @Override + public PDone apply(PCollection input) { + if (filenamePrefix == null) { + throw new IllegalStateException( + "need to set the filename prefix of a TextIO.Write transform"); + } + return new PDone(); + } + + /** + * Returns the current shard name template string. + */ + public String getShardNameTemplate() { + return shardTemplate; + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + @Override + protected String getKindString() { return "TextIO.Write"; } + + public String getFilenamePrefix() { + return filenamePrefix; + } + + public String getShardTemplate() { + return shardTemplate; + } + + public int getNumShards() { + return numShards; + } + + public String getFilenameSuffix() { + return filenameSuffix; + } + + public Coder getCoder() { + return coder; + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateWriteHelper(transform, context); + } + }); + } + } + } + + // Pattern which matches old-style shard output patterns, which are now + // disallowed. + private static final Pattern SHARD_OUTPUT_PATTERN = + Pattern.compile("@([0-9]+|\\*)"); + + private static void validateOutputComponent(String partialFilePattern) { + Preconditions.checkArgument( + !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(), + "Output name components are not allowed to contain @* or @N patterns: " + + partialFilePattern); + } + + ////////////////////////////////////////////////////////////////////////////// + + private static void evaluateReadHelper( + Read.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + TextSource source = new TextSource<>( + transform.filepattern, true, null, null, transform.coder); + List elems = readElemsFromSource(source); + context.setPCollection(transform.getOutput(), elems); + } + + private static void evaluateWriteHelper( + Write.Bound transform, + DirectPipelineRunner.EvaluationContext context) { + List elems = context.getPCollection(transform.getInput()); + int numShards = transform.numShards; + if (numShards < 1) { + // System gets to choose. For direct mode, choose 1. + numShards = 1; + } + TextSink> writer = TextSink.createForDirectPipelineRunner( + transform.filenamePrefix, transform.getShardNameTemplate(), + transform.filenameSuffix, numShards, + true, null, null, transform.coder); + try (Sink.SinkWriter> sink = writer.writer()) { + for (T elem : elems) { + sink.add(WindowedValue.valueInGlobalWindow(elem)); + } + } catch (IOException exn) { + throw new RuntimeException( + "unable to write to output file \"" + transform.filenamePrefix + "\"", + exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/package-info.java new file mode 100644 index 000000000000..886255e271d2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/package-info.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines transforms for reading and writing common storage formats, including + * {@link com.google.cloud.dataflow.sdk.io.AvroIO}, + * {@link com.google.cloud.dataflow.sdk.io.BigQueryIO}, and + * {@link com.google.cloud.dataflow.sdk.io.TextIO}. + * + *

The classes in this package provide {@code Read} transforms which create PCollections + * from existing storage: + *

{@code
+ * PCollection inputData = pipeline.apply(
+ *     BigQueryIO.Read.named("Read")
+ *                    .from("clouddataflow-readonly:samples.weather_stations");
+ * }
+ * and {@code Write} transforms which persist PCollections to external storage: + *
 {@code
+ * PCollection numbers = ...;
+ * numbers.apply(TextIO.Write.named("WriteNumbers")
+ *                           .to("gs://my_bucket/path/to/numbers"));
+ * } 
+ */ +package com.google.cloud.dataflow.sdk.io; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ApplicationNameOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ApplicationNameOptions.java new file mode 100644 index 000000000000..327e5c08445c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ApplicationNameOptions.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Options that allow setting the application name. + */ +public interface ApplicationNameOptions extends PipelineOptions { + /** + * Name of application, for display purposes. + *

+ * Defaults to the name of the class which constructs the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner}. + */ + @Description("Application name. Defaults to the name of the class which " + + "constructs the Pipeline.") + String getAppName(); + void setAppName(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BigQueryOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BigQueryOptions.java new file mode 100644 index 000000000000..b764f20918b0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BigQueryOptions.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Properties needed when using BigQuery with the Dataflow SDK. + */ +public interface BigQueryOptions extends ApplicationNameOptions, GcpOptions, + PipelineOptions, StreamingOptions { + @Description("Temporary staging dataset ID for BigQuery " + + "table operations") + @Default.String("bigquery.googleapis.com/cloud_dataflow") + String getTempDatasetId(); + void setTempDatasetId(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BlockingDataflowPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BlockingDataflowPipelineOptions.java new file mode 100644 index 000000000000..cdd5019b5df1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BlockingDataflowPipelineOptions.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.io.PrintStream; + +/** + * Options which are used to configure the {@link BlockingDataflowPipelineRunner}. + */ +public interface BlockingDataflowPipelineOptions extends DataflowPipelineOptions { + /** + * Output stream for job status messages. + */ + @JsonIgnore + @Default.InstanceFactory(StandardOutputFactory.class) + PrintStream getJobMessageOutput(); + void setJobMessageOutput(PrintStream value); + + /** + * Returns a default of {@link System#out}. + */ + public static class StandardOutputFactory implements DefaultValueFactory { + @Override + public PrintStream create(PipelineOptions options) { + return System.out; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptions.java new file mode 100644 index 000000000000..76de6e6dd8bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptions.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.util.List; + +/** + * Options used for testing and debugging the Dataflow SDK. + */ +public interface DataflowPipelineDebugOptions extends PipelineOptions { + /** + * Dataflow endpoint to use. + * + *

Defaults to the current version of the Google Cloud Dataflow + * API, at the time the current SDK version was released. + * + *

If the string contains "://", then this is treated as a url, + * otherwise {@link #getApiRootUrl()} is used as the root + * url. + */ + @Description("Cloud Dataflow Endpoint") + @Default.String("dataflow/v1b3/projects/") + String getDataflowEndpoint(); + void setDataflowEndpoint(String value); + + /** + * The list of backend experiments to enable. + * + *

Dataflow provides a number of experimental features that can be enabled + * with this flag. + * + *

Please sync with the Dataflow team when enabling any experiments. + */ + @Description("Backend experiments to enable.") + List getExperiments(); + void setExperiments(List value); + + /** + * The API endpoint to use when communicating with the Dataflow service. + */ + @Description("Google Cloud root API") + @Default.String("https://www.googleapis.com/") + String getApiRootUrl(); + void setApiRootUrl(String value); + + /** + * The path to write the translated Dataflow specification out to + * at job submission time. + */ + @Description("File for writing dataflow job descriptions") + String getDataflowJobFile(); + void setDataflowJobFile(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptions.java new file mode 100644 index 000000000000..7d0508873232 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptions.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.api.services.dataflow.Dataflow; +import com.google.cloud.dataflow.sdk.runners.DataflowPipeline; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.common.base.MoreObjects; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import org.joda.time.DateTimeUtils; +import org.joda.time.DateTimeZone; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +/** + * Options which can be used to configure the {@link DataflowPipeline}. + */ +public interface DataflowPipelineOptions extends + PipelineOptions, GcpOptions, ApplicationNameOptions, DataflowPipelineDebugOptions, + DataflowPipelineShuffleOptions, DataflowPipelineWorkerPoolOptions, BigQueryOptions, + GcsOptions, StreamingOptions { + + /** + * GCS path for temporary files. + *

+ * Must be a valid Cloud Storage url, beginning with the prefix "gs://" + *

+ * At least one of {@link #getTempLocation()} or {@link #getStagingLocation()} must be set. If + * {@link #getTempLocation()} is not set, then the Dataflow pipeline defaults to using + * {@link #getStagingLocation()}. + */ + @Description("GCS path for temporary files, eg \"gs://bucket/object\". " + + "Defaults to stagingLocation.") + String getTempLocation(); + void setTempLocation(String value); + + /** + * GCS path for staging local files. + *

+ * If {@link #getStagingLocation()} is not set, then the Dataflow pipeline defaults to a staging + * directory within {@link #getTempLocation}. + *

+ * At least one of {@link #getTempLocation()} or {@link #getStagingLocation()} must be set. + */ + @Description("GCS staging path. Defaults to a staging directory" + + " with the tempLocation") + String getStagingLocation(); + void setStagingLocation(String value); + + /** + * The job name is used as an idempotence key within the Dataflow service. If there + * is an existing job which is currently active, another job with the same name will + * not be able to be created. + */ + @Description("Dataflow job name, to uniquely identify active jobs. " + + "Defaults to using the ApplicationName-UserDame-Date.") + @Default.InstanceFactory(JobNameFactory.class) + String getJobName(); + void setJobName(String value); + + /** + * Returns a normalized job name constructed from {@link ApplicationNameOptions#getAppName()}, the + * local system user name (if available), and the current time. The normalization makes sure that + * the job name matches the required pattern of [a-z]([-a-z0-9]*[a-z0-9])? and length limit of 40 + * characters. + *

+ * This job name factory is only able to generate one unique name per second per application and + * user combination. + */ + public static class JobNameFactory implements DefaultValueFactory { + private static final DateTimeFormatter FORMATTER = + DateTimeFormat.forPattern("MMddHHmmss").withZone(DateTimeZone.UTC); + private static final int MAX_APP_NAME = 19; + private static final int MAX_USER_NAME = 9; + + @Override + public String create(PipelineOptions options) { + String appName = options.as(ApplicationNameOptions.class).getAppName(); + String normalizedAppName = appName == null || appName.length() == 0 ? "dataflow" + : appName.toLowerCase() + .replaceAll("[^a-z0-9]", "0") + .replaceAll("^[^a-z]", "a"); + String userName = MoreObjects.firstNonNull(System.getProperty("user.name"), ""); + String normalizedUserName = userName.toLowerCase() + .replaceAll("[^a-z0-9]", "0"); + String datePart = FORMATTER.print(DateTimeUtils.currentTimeMillis()); + + // Maximize the amount of the app name and user name we can use. + normalizedAppName = normalizedAppName.substring(0, + Math.min(normalizedAppName.length(), + MAX_APP_NAME + Math.max(0, MAX_USER_NAME - normalizedUserName.length()))); + normalizedUserName = normalizedUserName.substring(0, + Math.min(userName.length(), + MAX_USER_NAME + Math.max(0, MAX_APP_NAME - normalizedAppName.length()))); + return normalizedAppName + "-" + normalizedUserName + "-" + datePart; + } + } + + /** Alternative Dataflow client */ + @JsonIgnore + @Default.InstanceFactory(DataflowClientFactory.class) + Dataflow getDataflowClient(); + void setDataflowClient(Dataflow value); + + /** Returns the default Dataflow client built from the passed in PipelineOptions. */ + public static class DataflowClientFactory implements DefaultValueFactory { + @Override + public Dataflow create(PipelineOptions options) { + return Transport.newDataflowClient(options.as(DataflowPipelineOptions.class)).build(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineShuffleOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineShuffleOptions.java new file mode 100644 index 000000000000..f59f5eb5d78c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineShuffleOptions.java @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Options for Shuffle workers. Most users should not need to adjust the settings in this section. + */ +public interface DataflowPipelineShuffleOptions { + /** + * Disk source image to use by shuffle VMs for jobs. + * @see Compute Engine Images + */ + @Description("Dataflow shuffle VM disk image.") + String getShuffleDiskSourceImage(); + void setShuffleDiskSourceImage(String value); + + /** + * Number of workers to use with the shuffle appliance, or 0 to use + * the default number of workers. + */ + @Description("Number of shuffle workers, when using remote execution") + int getShuffleNumWorkers(); + void setShuffleNumWorkers(int value); + + /** + * Remote shuffle worker disk size, in gigabytes, or 0 to use the + * default size. + */ + @Description("Remote shuffle worker disk size, in gigabytes, or 0 to use the default size.") + int getShuffleDiskSizeGb(); + void setShuffleDiskSizeGb(int value); + + /** + * GCE availability zone for launching shuffle workers. + * + *

Default is up to the service. + */ + @Description("GCE availability zone for launching shuffle workers. " + + "Default is up to the service") + String getShuffleZone(); + void setShuffleZone(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineWorkerPoolOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineWorkerPoolOptions.java new file mode 100644 index 000000000000..6cd983931863 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineWorkerPoolOptions.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.util.List; + +/** + * Options which are used to configure the Dataflow pipeline worker pool. + */ +public interface DataflowPipelineWorkerPoolOptions { + /** + * Disk source image to use by VMs for jobs. + * @see Compute Engine Images + */ + @Description("Dataflow VM disk image.") + String getDiskSourceImage(); + void setDiskSourceImage(String value); + + /** + * Number of workers to use in remote execution. + */ + @Description("Number of workers, when using remote execution") + @Default.Integer(3) + int getNumWorkers(); + void setNumWorkers(int value); + + /** + * Remote worker disk size, in gigabytes, or 0 to use the default size. + */ + @Description("Remote worker disk size, in gigabytes, or 0 to use the default size.") + int getDiskSizeGb(); + void setDiskSizeGb(int value); + + /** + * GCE availability zone for launching workers. + * + *

Default is up to the service. + */ + @Description("GCE availability zone for launching workers. " + + "Default is up to the service") + String getZone(); + void setZone(String value); + + /** + * Type of API for handling cluster management,i.e. resizing, healthchecking, etc. + */ + public enum ClusterManagerApiType { + COMPUTE_ENGINE("compute.googleapis.com"), + REPLICA_POOL("replicapool.googleapis.com"); + + private final String apiServiceName; + + private ClusterManagerApiType(String apiServiceName) { + this.apiServiceName = apiServiceName; + } + + public String getApiServiceName() { + return this.apiServiceName; + } + } + + @Description("Type of API for handling cluster management,i.e. resizing, healthchecking, etc.") + @Default.InstanceFactory(ClusterManagerApiTypeFactory.class) + ClusterManagerApiType getClusterManagerApi(); + void setClusterManagerApi(ClusterManagerApiType value); + + /** Returns the default COMPUTE_ENGINE ClusterManagerApiType. */ + public static class ClusterManagerApiTypeFactory implements + DefaultValueFactory { + @Override + public ClusterManagerApiType create(PipelineOptions options) { + return ClusterManagerApiType.COMPUTE_ENGINE; + } + } + + /** + * Machine type to create worker VMs as. + */ + @Description("Dataflow VM machine type for workers.") + String getWorkerMachineType(); + void setWorkerMachineType(String value); + + /** + * Machine type to create VMs as. + */ + @Description("Dataflow VM machine type.") + String getMachineType(); + void setMachineType(String value); + + /** + * List of local files to make available to workers. + *

+ * Jars are placed on the worker's classpath. + *

+ * The default value is the list of jars from the main program's classpath. + */ + @Description("Files to stage on GCS and make available to " + + "workers. The default value is all files from the classpath.") + List getFilesToStage(); + void setFilesToStage(List value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerHarnessOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerHarnessOptions.java new file mode 100644 index 000000000000..0b8e1f809cc2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerHarnessOptions.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Options which are used exclusively within the Dataflow worker harness. + * These options have no effect at pipeline creation time. + */ +public interface DataflowWorkerHarnessOptions extends DataflowPipelineOptions { + /** + * ID of the worker running this pipeline. + */ + String getWorkerId(); + void setWorkerId(String value); + + /** + * ID of the job this pipeline represents. + */ + String getJobId(); + void setJobId(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Default.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Default.java new file mode 100644 index 000000000000..321fe744ca49 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Default.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * {@link Default} represents a set of annotations which can be used to annotate getter properties + * on {@link PipelineOptions} with information representing the default value to be returned + * if no value is specified. + */ +public @interface Default { + /** + * This represents that the default of the option is the specified {@link java.lang.Class} value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Class { + java.lang.Class value(); + } + + /** + * This represents that the default of the option is the specified {@link java.lang.String} + * value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface String { + java.lang.String value(); + } + + /** + * This represents that the default of the option is the specified boolean primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Boolean { + boolean value(); + } + + /** + * This represents that the default of the option is the specified char primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Character { + char value(); + } + + /** + * This represents that the default of the option is the specified byte primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Byte { + byte value(); + } + /** + * This represents that the default of the option is the specified short primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Short { + short value(); + } + /** + * This represents that the default of the option is the specified int primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Integer { + int value(); + } + + /** + * This represents that the default of the option is the specified long primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Long { + long value(); + } + + /** + * This represents that the default of the option is the specified float primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Float { + float value(); + } + + /** + * This represents that the default of the option is the specified double primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Double { + double value(); + } + + /** + * Value must be of type {@link DefaultValueFactory} and have a default constructor. + * Value is instantiated and then used as a type factory to generate the default. + *

+ * See {@link DefaultValueFactory} for more details. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface InstanceFactory { + java.lang.Class> value(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DefaultValueFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DefaultValueFactory.java new file mode 100644 index 000000000000..18fd7827798c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DefaultValueFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * An interface used with {@link Default.InstanceFactory} annotation to specify the class which will + * be an instance factory to produce default values for a given getter on {@link PipelineOptions}. + * When a property on a {@link PipelineOptions} is fetched, and is currently unset, the default + * value factory will be instantiated and invoked. + *

+ * Care must be taken to not produce an infinite loop when accessing other fields on the + * {@link PipelineOptions} object. + * + * @param The type of object this factory produces. + */ +public interface DefaultValueFactory { + /** + * Creates a default value for a getter marked with {@link Default.InstanceFactory}. + * + * @param options The current pipeline options. + * @return The default value to be used for the given pipeline options. + */ + T create(PipelineOptions options); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Description.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Description.java new file mode 100644 index 000000000000..9de8b1cd2580 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Description.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Machine-readable description for options in {@link PipelineOptions}. + */ +@Target(value = ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface Description { + String value(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DirectPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DirectPipelineOptions.java new file mode 100644 index 000000000000..85a280d99193 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DirectPipelineOptions.java @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; + +/** + * Options which can be used to configure the {@link DirectPipeline}. + */ +public interface DirectPipelineOptions extends + ApplicationNameOptions, BigQueryOptions, GcsOptions, GcpOptions, + PipelineOptions, StreamingOptions { + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcpOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcpOptions.java new file mode 100644 index 000000000000..7dbaa5fb32d9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcpOptions.java @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.cloud.dataflow.sdk.util.Credentials; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.io.File; +import java.io.IOException; +import java.security.GeneralSecurityException; + +/** + * Options used to configure Google Cloud Platform project and credentials. + *

+ * These options configure which of the following 4 different mechanisms for obtaining a credential + * are used: + *

    + *
  1. + * It can fetch the + * + * application default credentials. + *
  2. + *
  3. + * It can run the gcloud tool in a subprocess to obtain a credential. + * This is the preferred mechanism. The property "GCloudPath" can be + * used to specify where we search for gcloud data. + *
  4. + *
  5. + * The user can specify a client secrets file and go through the OAuth2 + * webflow. The credential will then be cached in the user's home + * directory for reuse. + *
  6. + *
  7. + * The user can specify a file containing a service account private key along + * with the service account name. + *
  8. + *
+ * The default mechanism is to use the + * + * application default credentials falling back to gcloud. The other options can be + * used by setting the corresponding properties. + */ +public interface GcpOptions extends PipelineOptions { + /** + * Project id to use when launching jobs. + */ + @Description("Project id. Required when running a Dataflow in the cloud.") + String getProject(); + void setProject(String value); + + /** + * This option controls which file to use when attempting to create the credentials using the + * OAuth 2 webflow. + */ + @Description("Path to a file containing Google API secret") + String getSecretsFile(); + void setSecretsFile(String value); + + /** + * This option controls which file to use when attempting to create the credentials using the + * service account method. + *

+ * This option if specified, needs be combined with the + * {@link GcpOptions#getServiceAccountName() serviceAccountName}. + */ + @Description("Path to a file containing the P12 service credentials") + String getServiceAccountKeyfile(); + void setServiceAccountKeyfile(String value); + + /** + * This option controls which service account to use when attempting to create the credentials + * using the service account method. + *

+ * This option if specified, needs be combined with the + * {@link GcpOptions#getServiceAccountKeyfile() serviceAccountKeyfile}. + */ + @Description("Name of the service account for Google APIs") + String getServiceAccountName(); + void setServiceAccountName(String value); + + @Description("The path to the gcloud binary. " + + " Default is to search the system path.") + String getGCloudPath(); + void setGCloudPath(String value); + + /** + * Directory for storing dataflow credentials. + */ + @Description("Directory for storing dataflow credentials") + @Default.InstanceFactory(CredentialDirFactory.class) + String getCredentialDir(); + void setCredentialDir(String value); + + /** + * Returns the default credential directory of ${user.home}/.store/data-flow. + */ + public static class CredentialDirFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + File home = new File(System.getProperty("user.home")); + File store = new File(home, ".store"); + File dataflow = new File(store, "data-flow"); + return dataflow.getPath(); + } + } + + @Description("The credential identifier when using a persistent" + + " credential store") + @Default.String("cloud_dataflow") + String getCredentialId(); + void setCredentialId(String value); + + /** Alternative Google Cloud Platform Credential */ + @JsonIgnore + @Description("Google Cloud Platform user credentials.") + @Default.InstanceFactory(GcpUserCredentialsFactory.class) + Credential getGcpCredential(); + void setGcpCredential(Credential value); + + /** + * Attempts to load the user credentials. See + * {@link Credentials#getUserCredential(GcpOptions)} for more details. + */ + public static class GcpUserCredentialsFactory implements DefaultValueFactory { + @Override + public Credential create(PipelineOptions options) { + try { + return Credentials.getUserCredential(options.as(GcpOptions.class)); + } catch (IOException | GeneralSecurityException e) { + throw new RuntimeException("Unable to obtain credential", e); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcsOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcsOptions.java new file mode 100644 index 000000000000..543c9cac6c40 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcsOptions.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.util.AppEngineEnvironment; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * Options used to configure Google Cloud Storage. + */ +public interface GcsOptions extends + ApplicationNameOptions, GcpOptions, PipelineOptions { + /** Alternative GcsUtil instance */ + @JsonIgnore + @Default.InstanceFactory(GcsUtil.GcsUtilFactory.class) + GcsUtil getGcsUtil(); + void setGcsUtil(GcsUtil value); + + //////////////////////////////////////////////////////////////////////////// + // Allows the user to provide an alternative ExecutorService if their + // environment does not support the default implementation. + @JsonIgnore + @Default.InstanceFactory(ExecutorServiceFactory.class) + ExecutorService getExecutorService(); + void setExecutorService(ExecutorService value); + + /** + * Returns the default {@link ExecutorService} to use within the Dataflow SDK. The + * {@link ExecutorService} is compatible with AppEngine. + */ + public static class ExecutorServiceFactory implements DefaultValueFactory { + @Override + public ExecutorService create(PipelineOptions options) { + ThreadFactoryBuilder threadFactoryBuilder = new ThreadFactoryBuilder(); + threadFactoryBuilder.setThreadFactory(MoreExecutors.platformThreadFactory()); + if (!AppEngineEnvironment.IS_APP_ENGINE) { + // AppEngine doesn't allow modification of threads to be daemon threads. + threadFactoryBuilder.setDaemon(true); + } + /* The SDK requires an unbounded thread pool because a step may create X writers + * each requiring their own thread to perform the writes otherwise a writer may + * block causing deadlock for the step because the writers buffer is full. + * Also, the MapTaskExecutor launches the steps in reverse order and completes + * them in forward order thus requiring enough threads so that each step's writers + * can be active. + */ + return new ThreadPoolExecutor( + 0, Integer.MAX_VALUE, // Allow an unlimited number of re-usable threads. + Long.MAX_VALUE, TimeUnit.NANOSECONDS, // Keep non-core threads alive forever. + new SynchronousQueue(), + threadFactoryBuilder.build()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java new file mode 100644 index 000000000000..d626b90d3c52 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.options.ProxyInvocationHandler.Deserializer; +import com.google.cloud.dataflow.sdk.options.ProxyInvocationHandler.Serializer; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +/** + * Dataflow SDK pipeline configuration options. + *

+ * Serialization + *

+ * For runners which execute their work remotely, every property available within PipelineOptions + * must either be serializable using Jackson's {@link ObjectMapper} or the getter method for the + * property annotated with {@link JsonIgnore @JsonIgnore}. + *

+ * It is an error to have the same property available in multiple interfaces with only some + * of them being annotated with {@link JsonIgnore @JsonIgnore}. It is also an error to mark a + * setter for a property with {@link JsonIgnore @JsonIgnore}. + */ +@JsonSerialize(using = Serializer.class) +@JsonDeserialize(using = Deserializer.class) +public interface PipelineOptions { + /** + * Transforms this object into an object of type . must extend {@link PipelineOptions}. + *

+ * If is not registered with the {@link PipelineOptionsFactory}, then we attempt to + * verify that is composable with every interface that this instance of the PipelineOptions + * has seen. + * + * @param kls The class of the type to transform to. + * @return An object of type kls. + */ + T as(Class kls); + + @Validation.Required + @Description("The runner which will be used when executing the pipeline.") + @Default.Class(DirectPipelineRunner.class) + Class> getRunner(); + void setRunner(Class> kls); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java new file mode 100644 index 000000000000..89a31b07e888 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java @@ -0,0 +1,862 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.worker.DataflowWorkerHarness; +import com.google.cloud.dataflow.sdk.testing.TestDataflowPipelineOptions; +import com.google.common.base.Equivalence; +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.base.Throwables; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Queues; +import com.google.common.collect.SetMultimap; +import com.google.common.collect.Sets; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.beans.BeanInfo; +import java.beans.IntrospectionException; +import java.beans.Introspector; +import java.beans.PropertyDescriptor; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Proxy; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; + +/** + * Constructs a {@link PipelineOptions} or any derived interface which is composable to any other + * derived interface of {@link PipelineOptions} via the {@link PipelineOptions#as} method. Being + * able to compose one derived interface of {@link PipelineOptions} to another has the following + * restrictions: + *

    + *
  • Any property with the same name must have the same return type for all derived interfaces + * of {@link PipelineOptions}. + *
  • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
  • Every method must conform to being a getter or setter for a JavaBean. + *
  • The derived interface of {@link PipelineOptions} must be composable with every interface + * registered with this factory. + *
+ *

+ * See the JavaBeans + * specification for more details as to what constitutes a property. + */ +public class PipelineOptionsFactory { + + /** + * Creates and returns an object which implements {@link PipelineOptions}. + * This sets the {@link ApplicationNameOptions#getAppName() "appName"} to the calling + * {@link Class#getSimpleName() classes simple name}. + * + * @return An object which implements {@link PipelineOptions}. + */ + public static PipelineOptions create() { + return new Builder(getAppName(3)).as(PipelineOptions.class); + } + + /** + * Creates and returns an object which implements @{code }. + * This sets the {@link ApplicationNameOptions#getAppName() "appName"} to the calling + * {@link Class#getSimpleName() classes simple name}. + *

+ * Note that @{code } must be composable with every registered interface with this factory. + * See {@link PipelineOptionsFactory#validateWellFormed(Class, Set)} for more details. + * + * @return An object which implements @{code }. + */ + public static T as(Class klass) { + return new Builder(getAppName(3)).as(klass); + } + + /** + * Sets the command line arguments to parse when constructing the {@link PipelineOptions}. + *

+ * Example GNU style command line arguments: + *

+   *   --project=MyProject (simple property, will set the "project" property to "MyProject")
+   *   --readOnly=true (for boolean properties, will set the "readOnly" property to "true")
+   *   --readOnly (shorthand for boolean properties, will set the "readOnly" property to "true")
+   *   --x=1 --x=2 --x=3 (list style property, will set the "x" property to [1, 2, 3])
+   *   --x=1,2,3 (shorthand list style property, will set the "x" property to [1, 2, 3])
+   * 
+ * Properties are able to bound to {@link String} and Java primitives @{code boolean}, + * @{code byte}, @{code short}, @{code int}, @{code long}, @{code float}, @{code double} and + * their primitive wrapper classes. + *

+ * List style properties are able to be bound to @{code boolean[]}, @{code char[]}, + * @{code short[]}, @{code int[]}, @{code long[]}, @{code float[]}, @{code double[]}, + * @{code String[]} and @{code List}. + */ + public static Builder fromArgs(String[] args) { + return new Builder(getAppName(3)).fromArgs(args); + } + + /** + * After creation we will validate that {@link PipelineOptions} conforms to all the + * validation criteria from {@code }. See + * {@link PipelineOptionsValidator#validate(Class, PipelineOptions)} for more details about + * validation. + */ + public Builder withValidation() { + return new Builder(getAppName(3)).withValidation(); + } + + /** A fluent PipelineOptions builder. */ + public static class Builder { + private final String defaultAppName; + private final String[] args; + private final boolean validation; + + // Do not allow direct instantiation + private Builder(String defaultAppName) { + this(defaultAppName, null, false); + } + + private Builder(String defaultAppName, String[] args, boolean validation) { + this.defaultAppName = defaultAppName; + this.args = args; + this.validation = validation; + } + + /** + * Sets the command line arguments to parse when constructing the {@link PipelineOptions}. + *

+ * Example GNU style command line arguments: + *

+     *   --project=MyProject (simple property, will set the "project" property to "MyProject")
+     *   --readOnly=true (for boolean properties, will set the "readOnly" property to "true")
+     *   --readOnly (shorthand for boolean properties, will set the "readOnly" property to "true")
+     *   --x=1 --x=2 --x=3 (list style property, will set the "x" property to [1, 2, 3])
+     *   --x=1,2,3 (shorthand list style property, will set the "x" property to [1, 2, 3])
+     * 
+ * Properties are able to bound to {@link String} and Java primitives @{code boolean}, + * @{code byte}, @{code short}, @{code int}, @{code long}, @{code float}, @{code double} and + * their primitive wrapper classes. + *

+ * List style properties are able to be bound to @{code boolean[]}, @{code char[]}, + * @{code short[]}, @{code int[]}, @{code long[]}, @{code float[]}, @{code double[]}, + * @{code String[]} and @{code List}. + */ + public Builder fromArgs(String[] args) { + Preconditions.checkNotNull(args, "Arguments should not be null."); + return new Builder(defaultAppName, args, validation); + } + + /** + * After creation we will validate that {@link PipelineOptions} conforms to all the + * validation criteria from {@code }. See + * {@link PipelineOptionsValidator#validate(Class, PipelineOptions)} for more details about + * validation. + */ + public Builder withValidation() { + return new Builder(defaultAppName, args, true); + } + + /** + * Creates and returns an object which implements {@link PipelineOptions} using the values + * configured on this builder during construction. + * + * @return An object which implements {@link PipelineOptions}. + */ + public PipelineOptions create() { + return as(PipelineOptions.class); + } + + /** + * Creates and returns an object which implements @{code } using the values configured on + * this builder during construction. + *

+ * Note that {@code } must be composable with every registered interface with this factory. + * See {@link PipelineOptionsFactory#validateWellFormed(Class, Set)} for more details. + * + * @return An object which implements @{code }. + */ + public T as(Class klass) { + Map initialOptions = Maps.newHashMap(); + + // Attempt to parse the arguments into the set of initial options to use + if (args != null) { + ListMultimap options = parseCommandLine(args); + LOG.debug("Provided Arguments: {}", options); + initialOptions = parseObjects(klass, options); + } + + // Create our proxy + ProxyInvocationHandler handler = new ProxyInvocationHandler(initialOptions); + T t = handler.as(klass); + + // Set the application name to the default if none was set. + ApplicationNameOptions appNameOptions = t.as(ApplicationNameOptions.class); + if (appNameOptions.getAppName() == null) { + appNameOptions.setAppName(defaultAppName); + } + + if (validation) { + PipelineOptionsValidator.validate(klass, t); + } + return t; + } + } + + /** + * Returns the simple name of calling class at the stack trace {@code level}. + */ + private static String getAppName(int level) { + StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); + try { + return Class.forName(stackTrace[level].getClassName()).getSimpleName(); + } catch (ClassNotFoundException e) { + return "unknown"; + } + } + + /** + * Stores the generated proxyClass and its respective {@link BeanInfo} object. + * + * @param The type of the proxyClass. + */ + static class Registration { + private final Class proxyClass; + private final List propertyDescriptors; + + public Registration(Class proxyClass, List beanInfo) { + this.proxyClass = proxyClass; + this.propertyDescriptors = beanInfo; + } + + List getPropertyDescriptors() { + return propertyDescriptors; + } + + Class getProxyClass() { + return proxyClass; + } + } + + + private static final Logger LOG = LoggerFactory.getLogger(PipelineOptionsFactory.class); + private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + private static final ObjectMapper MAPPER = new ObjectMapper(); + + // TODO: Add dynamic registration of pipeline runners. + private static final Map>> + SUPPORTED_PIPELINE_RUNNERS = + ImmutableMap.>>builder() + .put(DirectPipelineRunner.class.getSimpleName(), + DirectPipelineRunner.class) + .put(DataflowPipelineRunner.class.getSimpleName(), + DataflowPipelineRunner.class) + .put(BlockingDataflowPipelineRunner.class.getSimpleName(), + BlockingDataflowPipelineRunner.class) + .build(); + + /** Methods which are ignored when validating the proxy class */ + private static final Set IGNORED_METHODS; + + /** The set of options which have been registered and visible to the user. */ + private static final Set> REGISTERED_OPTIONS = + Sets.newConcurrentHashSet(); + + /** A cache storing a mapping from a given interface to its registration record. */ + private static final Map, Registration> INTERFACE_CACHE = + Maps.newConcurrentMap(); + + /** A cache storing a mapping from a set of interfaces to its registration record. */ + private static final Map>, Registration> COMBINED_CACHE = + Maps.newConcurrentMap(); + + static { + try { + IGNORED_METHODS = ImmutableSet.builder() + .add(Object.class.getMethod("getClass")) + .add(Object.class.getMethod("wait")) + .add(Object.class.getMethod("wait", long.class)) + .add(Object.class.getMethod("wait", long.class, int.class)) + .add(Object.class.getMethod("notify")) + .add(Object.class.getMethod("notifyAll")) + .add(Proxy.class.getMethod("getInvocationHandler", Object.class)) + .build(); + } catch (NoSuchMethodException | SecurityException e) { + LOG.error("Unable to find expected method", e); + throw new ExceptionInInitializerError(e); + } + + // TODO Add support for dynamically loading and registering the options interfaces. + register(PipelineOptions.class); + register(DirectPipelineOptions.class); + register(DataflowPipelineOptions.class); + register(BlockingDataflowPipelineOptions.class); + register(TestDataflowPipelineOptions.class); + } + + /** + * This registers the interface with this factory. This interface must conform to the following + * restrictions: + *

    + *
  • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
  • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
  • Every method must conform to being a getter or setter for a JavaBean. + *
  • The derived interface of {@link PipelineOptions} must be composable with every interface + * registered with this factory. + *
+ * + * @param iface The interface object to manually register. + */ + public static synchronized void register(Class iface) { + Preconditions.checkNotNull(iface); + Preconditions.checkArgument(iface.isInterface(), "Only interface types are supported."); + + if (REGISTERED_OPTIONS.contains(iface)) { + return; + } + validateWellFormed(iface, REGISTERED_OPTIONS); + REGISTERED_OPTIONS.add(iface); + } + + /** + * Validates that the interface conforms to the following: + *
    + *
  • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
  • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
  • Every method must conform to being a getter or setter for a JavaBean. + *
  • The derived interface of {@link PipelineOptions} must be composable with every interface + * part of allPipelineOptionsClasses. + *
  • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. + *
  • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for + * this property must be annotated with {@link JsonIgnore @JsonIgnore}. + *
+ * + * @param iface The interface to validate. + * @param validatedPipelineOptionsInterfaces The set of validated pipeline options interfaces to + * validate against. + * @return A registration record containing the proxy class and bean info for iface. + */ + static synchronized Registration validateWellFormed( + Class iface, Set> validatedPipelineOptionsInterfaces) { + Preconditions.checkArgument(iface.isInterface(), "Only interface types are supported."); + + Set> combinedPipelineOptionsInterfaces = + FluentIterable.from(validatedPipelineOptionsInterfaces).append(iface).toSet(); + // Validate that the view of all currently passed in options classes is well formed. + if (!COMBINED_CACHE.containsKey(combinedPipelineOptionsInterfaces)) { + Class allProxyClass = Proxy.getProxyClass(PipelineOptionsFactory.class.getClassLoader(), + combinedPipelineOptionsInterfaces.toArray(EMPTY_CLASS_ARRAY)); + try { + List propertyDescriptors = + getPropertyDescriptors(allProxyClass); + validateClass(iface, validatedPipelineOptionsInterfaces, + allProxyClass, propertyDescriptors); + COMBINED_CACHE.put(combinedPipelineOptionsInterfaces, + new Registration((Class) allProxyClass, propertyDescriptors)); + } catch (IntrospectionException e) { + throw Throwables.propagate(e); + } + } + + // Validate that the local view of the class is well formed. + if (!INTERFACE_CACHE.containsKey(iface)) { + Class proxyClass = Proxy.getProxyClass( + PipelineOptionsFactory.class.getClassLoader(), new Class[] {iface}); + try { + List propertyDescriptors = + getPropertyDescriptors(proxyClass); + validateClass(iface, validatedPipelineOptionsInterfaces, proxyClass, propertyDescriptors); + INTERFACE_CACHE.put(iface, + new Registration((Class) proxyClass, propertyDescriptors)); + } catch (IntrospectionException e) { + throw Throwables.propagate(e); + } + } + return (Registration) INTERFACE_CACHE.get(iface); + } + + public static Set> getRegisteredOptions() { + return Collections.unmodifiableSet(REGISTERED_OPTIONS); + } + + static List getPropertyDescriptors( + Set> interfaces) { + return COMBINED_CACHE.get(interfaces).getPropertyDescriptors(); + } + + + /** + * Creates a set of {@link DataflowWorkerHarnessOptions} based of a set of known system + * properties. This is meant to only be used from the {@link DataflowWorkerHarness} as a method to + * bootstrap the worker harness. + * + * @return A {@link DataflowWorkerHarnessOptions} object configured for the + * {@link DataflowWorkerHarness}. + */ + @Deprecated + public static DataflowWorkerHarnessOptions createFromSystemProperties() { + DataflowWorkerHarnessOptions options = as(DataflowWorkerHarnessOptions.class); + options.setRunner(null); + if (System.getProperties().containsKey("root_url")) { + options.setApiRootUrl(System.getProperty("root_url")); + } + if (System.getProperties().containsKey("service_path")) { + options.setDataflowEndpoint(System.getProperty("service_path")); + } + if (System.getProperties().containsKey("temp_gcs_directory")) { + options.setTempLocation(System.getProperty("temp_gcs_directory")); + } + if (System.getProperties().containsKey("service_account_name")) { + options.setServiceAccountName(System.getProperty("service_account_name")); + } + if (System.getProperties().containsKey("service_account_keyfile")) { + options.setServiceAccountKeyfile(System.getProperty("service_account_keyfile")); + } + if (System.getProperties().containsKey("worker_id")) { + options.setWorkerId(System.getProperty("worker_id")); + } + if (System.getProperties().containsKey("project_id")) { + options.setProject(System.getProperty("project_id")); + } + if (System.getProperties().containsKey("job_id")) { + options.setJobId(System.getProperty("job_id")); + } + return options; + } + + /** + * Returns all the methods visible from the provided interfaces. + * + * @param interfaces The interfaces to use when searching for all their methods. + * @return An iterable of {@link Method}s which interfaces expose. + */ + static Iterable getClosureOfMethodsOnInterfaces( + Iterable> interfaces) { + return FluentIterable.from(interfaces).transformAndConcat( + new Function, Iterable>() { + @Override + public Iterable apply(Class input) { + return getClosureOfMethodsOnInterface(input); + } + }); + } + + /** + * Returns all the methods visible from {@code iface}. + * + * @param iface The interface to use when searching for all its methods. + * @return An iterable of {@link Method}s which {@code iface} exposes. + */ + static Iterable getClosureOfMethodsOnInterface(Class iface) { + Preconditions.checkNotNull(iface); + Preconditions.checkArgument(iface.isInterface()); + ImmutableList.Builder builder = ImmutableList.builder(); + Queue> interfacesToProcess = Queues.newArrayDeque(); + interfacesToProcess.add(iface); + while (!interfacesToProcess.isEmpty()) { + Class current = interfacesToProcess.remove(); + builder.add(current.getMethods()); + interfacesToProcess.addAll(Arrays.asList(current.getInterfaces())); + } + return builder.build(); + } + + /** + * This method is meant to emulate the behavior of {@link Introspector#getBeanInfo(Class, int)} + * to construct the list of {@link PropertyDescriptor}. + *

+ * TODO: Swap back to using Introspector once the proxy class issue with AppEngine is resolved. + */ + private static List getPropertyDescriptors(Class beanClass) + throws IntrospectionException { + // The sorting is important to make this method stable. + SortedSet methods = Sets.newTreeSet(MethodComparator.INSTANCE); + methods.addAll(Arrays.asList(beanClass.getMethods())); + // Build a map of property names to getters. + SortedMap propertyNamesToGetters = Maps.newTreeMap(); + for (Method method : methods) { + String methodName = method.getName(); + if ((!methodName.startsWith("get") + && !methodName.startsWith("is")) + || method.getParameterTypes().length != 0 + || method.getReturnType() == void.class) { + continue; + } + String propertyName = Introspector.decapitalize( + methodName.startsWith("is") ? methodName.substring(2) : methodName.substring(3)); + propertyNamesToGetters.put(propertyName, method); + } + + List descriptors = Lists.newArrayList(); + + /* + * Add all the getter/setter pairs to the list of descriptors removing the getter once + * it has been paired up. + */ + for (Method method : methods) { + String methodName = method.getName(); + if (!methodName.startsWith("set") + || method.getParameterTypes().length != 1 + || method.getReturnType() != void.class) { + continue; + } + String propertyName = Introspector.decapitalize(methodName.substring(3)); + descriptors.add(new PropertyDescriptor( + propertyName, propertyNamesToGetters.remove(propertyName), method)); + } + + // Add the remaining getters with missing setters. + for (Map.Entry getterToMethod : propertyNamesToGetters.entrySet()) { + descriptors.add(new PropertyDescriptor( + getterToMethod.getKey(), getterToMethod.getValue(), null)); + } + return descriptors; + } + + /** + * Validates that a given class conforms to the following properties: + *

    + *
  • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
  • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
  • Every method must conform to being a getter or setter for a JavaBean. + *
  • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. + *
  • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for + * this property must be annotated with {@link JsonIgnore @JsonIgnore}. + *
+ * + * @param iface The interface to validate. + * @param validatedPipelineOptionsInterfaces The set of validated pipeline options interfaces to + * validate against. + * @param klass The proxy class representing the interface. + * @param descriptors A list of {@link PropertyDescriptor}s to use when validating. + */ + private static void validateClass(Class iface, + Set> validatedPipelineOptionsInterfaces, + Class klass, List descriptors) { + Set methods = Sets.newHashSet(IGNORED_METHODS); + // Ignore static methods, "equals", "hashCode", "toString" and "as" on the generated class. + for (Method method : klass.getMethods()) { + if (Modifier.isStatic(method.getModifiers())) { + methods.add(method); + } + } + try { + methods.add(klass.getMethod("equals", Object.class)); + methods.add(klass.getMethod("hashCode")); + methods.add(klass.getMethod("toString")); + methods.add(klass.getMethod("as", Class.class)); + } catch (NoSuchMethodException | SecurityException e) { + throw Throwables.propagate(e); + } + + // Verify that there are no methods with the same name with two different return types. + Iterable interfaceMethods = FluentIterable + .from(getClosureOfMethodsOnInterface(iface)) + .toSortedSet(MethodComparator.INSTANCE); + SetMultimap, Method> methodNameToMethodMap = + HashMultimap.create(); + for (Method method : interfaceMethods) { + methodNameToMethodMap.put(MethodNameEquivalence.INSTANCE.wrap(method), method); + } + for (Map.Entry, Collection> entry + : methodNameToMethodMap.asMap().entrySet()) { + Set> returnTypes = FluentIterable.from(entry.getValue()) + .transform(ReturnTypeFetchingFunction.INSTANCE).toSet(); + SortedSet collidingMethods = FluentIterable.from(entry.getValue()) + .toSortedSet(MethodComparator.INSTANCE); + Preconditions.checkArgument(returnTypes.size() == 1, + "Method [%s] has multiple definitions %s with different return types for [%s].", + entry.getKey().get().getName(), + collidingMethods, + iface.getName()); + } + + // Verify that there is no getter with a mixed @JsonIgnore annotation and verify + // that no setter has @JsonIgnore. + Iterable allInterfaceMethods = FluentIterable + .from(getClosureOfMethodsOnInterfaces(validatedPipelineOptionsInterfaces)) + .append(getClosureOfMethodsOnInterface(iface)) + .toSortedSet(MethodComparator.INSTANCE); + SetMultimap, Method> methodNameToAllMethodMap = + HashMultimap.create(); + for (Method method : allInterfaceMethods) { + methodNameToAllMethodMap.put(MethodNameEquivalence.INSTANCE.wrap(method), method); + } + for (PropertyDescriptor descriptor : descriptors) { + if (IGNORED_METHODS.contains(descriptor.getReadMethod()) + || IGNORED_METHODS.contains(descriptor.getWriteMethod())) { + continue; + } + Set getters = + methodNameToAllMethodMap.get( + MethodNameEquivalence.INSTANCE.wrap(descriptor.getReadMethod())); + Set gettersWithJsonIgnore = + FluentIterable.from(getters).filter(JsonIgnorePredicate.INSTANCE).toSet(); + + Iterable getterClassNames = FluentIterable.from(getters) + .transform(MethodToDeclaringClassFunction.INSTANCE) + .transform(ClassNameFunction.INSTANCE); + Iterable gettersWithJsonIgnoreClassNames = FluentIterable.from(gettersWithJsonIgnore) + .transform(MethodToDeclaringClassFunction.INSTANCE) + .transform(ClassNameFunction.INSTANCE); + + Preconditions.checkArgument(gettersWithJsonIgnore.isEmpty() + || getters.size() == gettersWithJsonIgnore.size(), + "Expected getter for property [%s] to be marked with @JsonIgnore on all %s, " + + "found only on %s", + descriptor.getName(), getterClassNames, gettersWithJsonIgnoreClassNames); + + Set settersWithJsonIgnore = FluentIterable.from( + methodNameToAllMethodMap.get( + MethodNameEquivalence.INSTANCE.wrap(descriptor.getWriteMethod()))) + .filter(JsonIgnorePredicate.INSTANCE).toSet(); + + Iterable settersWithJsonIgnoreClassNames = FluentIterable.from(settersWithJsonIgnore) + .transform(MethodToDeclaringClassFunction.INSTANCE) + .transform(ClassNameFunction.INSTANCE); + + Preconditions.checkArgument(settersWithJsonIgnore.isEmpty(), + "Expected setter for property [%s] to not be marked with @JsonIgnore on %s", + descriptor.getName(), settersWithJsonIgnoreClassNames); + } + + // Verify that each property has a matching read and write method. + for (PropertyDescriptor propertyDescriptor : descriptors) { + Preconditions.checkArgument( + IGNORED_METHODS.contains(propertyDescriptor.getWriteMethod()) + || propertyDescriptor.getReadMethod() != null, + "Expected getter for property [%s] of type [%s] on [%s].", + propertyDescriptor.getName(), + propertyDescriptor.getPropertyType().getName(), + iface.getName()); + Preconditions.checkArgument( + IGNORED_METHODS.contains(propertyDescriptor.getReadMethod()) + || propertyDescriptor.getWriteMethod() != null, + "Expected setter for property [%s] of type [%s] on [%s].", + propertyDescriptor.getName(), + propertyDescriptor.getPropertyType().getName(), + iface.getName()); + methods.add(propertyDescriptor.getReadMethod()); + methods.add(propertyDescriptor.getWriteMethod()); + } + + // Verify that no additional methods are on an interface that aren't a bean property. + Set unknownMethods = Sets.difference(Sets.newHashSet(klass.getMethods()), methods); + Preconditions.checkArgument(unknownMethods.isEmpty(), + "Methods %s on [%s] do not conform to being bean properties.", + FluentIterable.from(unknownMethods).transform(MethodFormatterFunction.INSTANCE), + iface.getName()); + } + + /** A {@link Comparator} which uses the generic method signature to sort them. */ + private static class MethodComparator implements Comparator { + static final MethodComparator INSTANCE = new MethodComparator(); + @Override + public int compare(Method o1, Method o2) { + return o1.toGenericString().compareTo(o2.toGenericString()); + } + } + + /** A {@link Function} which gets the methods return type. */ + private static class ReturnTypeFetchingFunction implements Function> { + static final ReturnTypeFetchingFunction INSTANCE = new ReturnTypeFetchingFunction(); + @Override + public Class apply(Method input) { + return input.getReturnType(); + } + } + + /** A {@link Function} which turns a method into a simple method signature. */ + private static class MethodFormatterFunction implements Function { + static final MethodFormatterFunction INSTANCE = new MethodFormatterFunction(); + @Override + public String apply(Method input) { + String parameterTypes = FluentIterable.of(input.getParameterTypes()) + .transform(ClassNameFunction.INSTANCE) + .toSortedList(String.CASE_INSENSITIVE_ORDER) + .toString(); + return ClassNameFunction.INSTANCE.apply(input.getReturnType()) + " " + input.getName() + + "(" + parameterTypes.substring(1, parameterTypes.length() - 1) + ")"; + } + } + + /** A {@link Function} with returns the classes name. */ + private static class ClassNameFunction implements Function, String> { + static final ClassNameFunction INSTANCE = new ClassNameFunction(); + @Override + public String apply(Class input) { + return input.getName(); + } + } + + /** A {@link Function} with returns the declaring class for the method. */ + private static class MethodToDeclaringClassFunction implements Function> { + static final MethodToDeclaringClassFunction INSTANCE = new MethodToDeclaringClassFunction(); + @Override + public Class apply(Method input) { + return input.getDeclaringClass(); + } + } + + /** An {@link Equivalence} which considers two methods equivalent if they share the same name. */ + private static class MethodNameEquivalence extends Equivalence { + static final MethodNameEquivalence INSTANCE = new MethodNameEquivalence(); + @Override + protected boolean doEquivalent(Method a, Method b) { + return a.getName().equals(b.getName()); + } + + @Override + protected int doHash(Method t) { + return t.getName().hashCode(); + } + } + + /** + * A {@link Predicate} which returns true if the method is annotated with + * {@link JsonIgnore @JsonIgnore}. + */ + static class JsonIgnorePredicate implements Predicate { + static final JsonIgnorePredicate INSTANCE = new JsonIgnorePredicate(); + @Override + public boolean apply(Method input) { + return input.isAnnotationPresent(JsonIgnore.class); + } + } + + /** + * Splits string arguments based upon expected pattern of --argName=value. + *

+ * Example GNU style command line arguments: + *

+   *   --project=MyProject (simple property, will set the "project" property to "MyProject")
+   *   --readOnly=true (for boolean properties, will set the "readOnly" property to "true")
+   *   --readOnly (shorthand for boolean properties, will set the "readOnly" property to "true")
+   *   --x=1 --x=2 --x=3 (list style property, will set the "x" property to [1, 2, 3])
+   *   --x=1,2,3 (shorthand list style property, will set the "x" property to [1, 2, 3])
+   * 
+ * Properties are able to bound to {@link String} and Java primitives boolean, byte, + * short, int, long, float, double and their primitive wrapper classes. + *

+ * List style properties are able to be bound to boolean[], char[], short[], + * int[], long[], float[], double[], String[] and List. + *

+ */ + private static ListMultimap parseCommandLine(String[] args) { + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + for (String arg : args) { + Preconditions.checkArgument(arg.startsWith("--"), + "Unknown argument %s in command line %s", arg, Arrays.toString(args)); + int index = arg.indexOf("="); + // Make sure that '=' isn't the first character after '--' or the last character + Preconditions.checkArgument(index != 2 && index != arg.length() - 1, + "Unknown argument %s in command line %s", arg, Arrays.toString(args)); + if (index > 0) { + builder.put(arg.substring(2, index), arg.substring(index + 1, arg.length())); + } else { + builder.put(arg.substring(2), "true"); + } + } + return builder.build(); + } + + /** + * Using the parsed string arguments, we convert the strings to the expected + * return type of the methods which are found on the passed in class. + *

+ * For any return type that is expected to be an array or a collection, we further + * split up each string on ','. + *

+ * We special case the "runner" option. It is mapped to the class of the {@link PipelineRunner} + * based off of the {@link PipelineRunner}s simple class name. + */ + private static Map parseObjects( + Class klass, ListMultimap options) { + Map propertyNamesToGetters = Maps.newHashMap(); + PipelineOptionsFactory.validateWellFormed(klass, getRegisteredOptions()); + Iterable propertyDescriptors = + PipelineOptionsFactory.getPropertyDescriptors( + FluentIterable.from(getRegisteredOptions()).append(klass).toSet()); + for (PropertyDescriptor descriptor : propertyDescriptors) { + propertyNamesToGetters.put(descriptor.getName(), descriptor.getReadMethod()); + } + Map convertedOptions = Maps.newHashMap(); + for (Map.Entry> entry : options.asMap().entrySet()) { + if (!propertyNamesToGetters.containsKey(entry.getKey())) { + LOG.warn("Ignoring argument {}={}", entry.getKey(), entry.getValue()); + continue; + } + + Method method = propertyNamesToGetters.get(entry.getKey()); + JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); + if ("runner".equals(entry.getKey())) { + String runner = Iterables.getOnlyElement(entry.getValue()); + Preconditions.checkArgument(SUPPORTED_PIPELINE_RUNNERS.containsKey(runner), + "Unknown 'runner' specified %s, supported pipeline runners %s", + runner, SUPPORTED_PIPELINE_RUNNERS.keySet()); + convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + } else if (method.getReturnType().isArray() + || Collection.class.isAssignableFrom(method.getReturnType())) { + // Split any strings with "," + List values = FluentIterable.from(entry.getValue()) + .transformAndConcat(new Function>() { + @Override + public Iterable apply(String input) { + return Arrays.asList(input.split(",")); + } + }).toList(); + convertedOptions.put(entry.getKey(), MAPPER.convertValue(values, type)); + } else { + String value = Iterables.getOnlyElement(entry.getValue()); + convertedOptions.put(entry.getKey(), MAPPER.convertValue(value, type)); + } + } + return convertedOptions; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidator.java new file mode 100644 index 000000000000..bb7bcf3de831 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidator.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.common.base.Preconditions; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; + +/** + * Validates that the {@link PipelineOptions} conforms to all the {@link Validation} criteria. + */ +public class PipelineOptionsValidator { + /** + * Validates that the passed {@link PipelineOptions} conforms to all the validation criteria from + * the passed in interface. + *

+ * Note that the interface requested must conform to the validation criteria specified on + * {@link PipelineOptions#as(Class)}. + * + * @param klass The interface to fetch validation criteria from. + * @param options The {@link PipelineOptions} to validate. + * @return The type + */ + public static T validate(Class klass, PipelineOptions options) { + Preconditions.checkNotNull(klass); + Preconditions.checkNotNull(options); + Preconditions.checkArgument(Proxy.isProxyClass(options.getClass())); + Preconditions.checkArgument(Proxy.getInvocationHandler(options) + instanceof ProxyInvocationHandler); + + ProxyInvocationHandler handler = + (ProxyInvocationHandler) Proxy.getInvocationHandler(options); + for (Method method : PipelineOptionsFactory.getClosureOfMethodsOnInterface(klass)) { + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof Validation.Required) { + Preconditions.checkArgument(handler.invoke(options, method, null) != null, + "Expected non-null property to be set for [" + method + "]."); + } + } + } + return options.as(klass); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandler.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandler.java new file mode 100644 index 000000000000..aefbe1dec294 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandler.java @@ -0,0 +1,390 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory.JsonIgnorePredicate; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory.Registration; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.common.base.Defaults; +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ClassToInstanceMap; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.collect.MutableClassToInstanceMap; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.node.ObjectNode; + +import java.beans.PropertyDescriptor; +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.lang.reflect.Type; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +/** + * Represents and {@link InvocationHandler} for a {@link Proxy}. The invocation handler uses bean + * introspection of the proxy class to store and retrieve values based off of the property name. + *

+ * Unset properties use the {@Default} metadata on the getter to return values. If there + * is no {@Default} annotation on the getter, then a default as + * per the Java Language Specification for the expected return type is returned. + *

+ * In addition to the getter/setter pairs, this proxy invocation handler supports + * {@link Object#equals(Object)}, {@link Object#hashCode()}, {@link Object#toString()} and + * {@link PipelineOptions#as(Class)}. + */ +class ProxyInvocationHandler implements InvocationHandler { + private static final ObjectMapper MAPPER = new ObjectMapper(); + /** + * No two instances of this class are considered equivalent hence we generate a random hash code + * between 0 and {@link Integer#MAX_VALUE}. + */ + private final int hashCode = (int) Math.random() * Integer.MAX_VALUE; + private final Set> knownInterfaces; + private final ClassToInstanceMap interfaceToProxyCache; + private final Map options; + private final Map jsonOptions; + private final Map gettersToPropertyNames; + private final Map settersToPropertyNames; + + ProxyInvocationHandler(Map options) { + this(options, Maps.newHashMap()); + } + + private ProxyInvocationHandler(Map options, Map jsonOptions) { + this.options = options; + this.jsonOptions = jsonOptions; + this.knownInterfaces = new HashSet<>(PipelineOptionsFactory.getRegisteredOptions()); + gettersToPropertyNames = Maps.newHashMap(); + settersToPropertyNames = Maps.newHashMap(); + interfaceToProxyCache = MutableClassToInstanceMap.create(); + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + if (args == null && "toString".equals(method.getName())) { + return toString(); + } else if (args != null && args.length == 1 && "equals".equals(method.getName())) { + return equals(args[0]); + } else if (args == null && "hashCode".equals(method.getName())) { + return hashCode(); + } else if (args != null && "as".equals(method.getName()) && args[0] instanceof Class) { + return as((Class) args[0]); + } + String methodName = method.getName(); + synchronized (this) { + if (gettersToPropertyNames.keySet().contains(methodName)) { + String propertyName = gettersToPropertyNames.get(methodName); + if (!options.containsKey(propertyName)) { + // Lazy bind the default to the method. + Object value = jsonOptions.containsKey(propertyName) + ? getValueFromJson(propertyName, method) + : getDefault((PipelineOptions) proxy, method); + options.put(propertyName, value); + } + return options.get(propertyName); + } else if (settersToPropertyNames.containsKey(methodName)) { + options.put(settersToPropertyNames.get(methodName), args[0]); + return Void.TYPE; + } + } + throw new RuntimeException("Unknown method [" + method + "] invoked with args [" + + Arrays.toString(args) + "]."); + } + + /** + * Backing implementation for {@link PipelineOptions#as(Class)}. + * + * @param iface The interface which the returned object needs to implement. + * @return An object which implements the interface . + */ + synchronized T as(Class iface) { + Preconditions.checkNotNull(iface); + Preconditions.checkArgument(iface.isInterface()); + if (!interfaceToProxyCache.containsKey(iface)) { + Registration registration = + PipelineOptionsFactory.validateWellFormed(iface, knownInterfaces); + List propertyDescriptors = registration.getPropertyDescriptors(); + Class proxyClass = registration.getProxyClass(); + gettersToPropertyNames.putAll(generateGettersToPropertyNames(propertyDescriptors)); + settersToPropertyNames.putAll(generateSettersToPropertyNames(propertyDescriptors)); + knownInterfaces.add(iface); + interfaceToProxyCache.putInstance(iface, + InstanceBuilder.ofType(proxyClass) + .fromClass(proxyClass) + .withArg(InvocationHandler.class, this) + .build()); + } + return interfaceToProxyCache.getInstance(iface); + } + + + /** + * Returns true if the other object is a ProxyInvocationHandler or is a Proxy object and has the + * same ProxyInvocationHandler as this. + * + * @param obj The object to compare against this. + * @return true iff the other object is a ProxyInvocationHandler or is a Proxy object and has the + * same ProxyInvocationHandler as this. + */ + @Override + public boolean equals(Object obj) { + return obj != null && ((obj instanceof ProxyInvocationHandler && this == obj) + || (Proxy.isProxyClass(obj.getClass()) && this == Proxy.getInvocationHandler(obj))); + } + + /** + * Each instance of this ProxyInvocationHandler is unique and has a random hash code. + * + * @return A hash code that was generated randomly. + */ + @Override + public int hashCode() { + return hashCode; + } + + /** + * This will output all the currently set values. + * + * @return A string representation of this. + */ + @Override + public synchronized String toString() { + StringBuilder b = new StringBuilder(); + b.append("Current Settings:\n"); + for (Map.Entry entry : new TreeMap<>(options).entrySet()) { + b.append(" " + entry.getKey() + ": " + entry.getValue() + "\n"); + } + return b.toString(); + } + + /** + * Uses a Jackson {@link ObjectMapper} to attempt type conversion. + * + * @param method The method whose return type you would like to return. + * @param propertyName The name of the property which is being returned. + * @return An object matching the return type of the method passed in. + */ + private Object getValueFromJson(String propertyName, Method method) { + try { + JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); + JsonNode jsonNode = jsonOptions.get(propertyName); + return MAPPER.readValue(jsonNode.toString(), type); + } catch (IOException e) { + throw new RuntimeException("Unable to parse representation", e); + } + } + + /** + * Returns a default value for the method based upon {@Default} metadata on the getter + * to return values. If there is no {@Default} annotation on the getter, then a default as + * per the Java Language Specification for the expected return type is returned. + * + * @param proxy The proxy object for which we are attempting to get the default. + * @param method The getter method which was invoked. + * @return The default value from an {@link Default} annotation if present, otherwise a default + * value as per the Java Language Specification. + */ + private Object getDefault(PipelineOptions proxy, Method method) { + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof Default.Class) { + return ((Default.Class) annotation).value(); + } else if (annotation instanceof Default.String) { + return ((Default.String) annotation).value(); + } else if (annotation instanceof Default.Boolean) { + return ((Default.Boolean) annotation).value(); + } else if (annotation instanceof Default.Character) { + return ((Default.Character) annotation).value(); + } else if (annotation instanceof Default.Byte) { + return ((Default.Byte) annotation).value(); + } else if (annotation instanceof Default.Short) { + return ((Default.Short) annotation).value(); + } else if (annotation instanceof Default.Integer) { + return ((Default.Integer) annotation).value(); + } else if (annotation instanceof Default.Long) { + return ((Default.Long) annotation).value(); + } else if (annotation instanceof Default.Float) { + return ((Default.Float) annotation).value(); + } else if (annotation instanceof Default.Double) { + return ((Default.Double) annotation).value(); + } else if (annotation instanceof Default.String) { + return ((Default.String) annotation).value(); + } else if (annotation instanceof Default.String) { + return ((Default.String) annotation).value(); + } else if (annotation instanceof Default.String) { + return ((Default.String) annotation).value(); + } else if (annotation instanceof Default.InstanceFactory) { + return InstanceBuilder.ofType(((Default.InstanceFactory) annotation).value()) + .build() + .create(proxy); + } + } + + /* + * We need to make sure that we return something appropriate for the return type. Thus we return + * a default value as defined by the JLS. + */ + return Defaults.defaultValue(method.getReturnType()); + } + + /** + * Returns a map from the getters method name to the name of the property based upon the passed in + * {@link PropertyDescriptor}s property descriptors. + * + * @param propertyDescriptors A list of {@link PropertyDescriptor}s to use when generating the + * map. + * @return A map of getter method name to property name. + */ + private static Map generateGettersToPropertyNames( + List propertyDescriptors) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (PropertyDescriptor descriptor : propertyDescriptors) { + if (descriptor.getReadMethod() != null) { + builder.put(descriptor.getReadMethod().getName(), descriptor.getName()); + } + } + return builder.build(); + } + + /** + * Returns a map from the setters method name to its matching getters method name based upon the + * passed in {@link PropertyDescriptor}s property descriptors. + * + * @param propertyDescriptors A list of {@link PropertyDescriptor}s to use when generating the + * map. + * @return A map of setter method name to getter method name. + */ + private static Map generateSettersToPropertyNames( + List propertyDescriptors) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (PropertyDescriptor descriptor : propertyDescriptors) { + if (descriptor.getWriteMethod() != null) { + builder.put(descriptor.getWriteMethod().getName(), descriptor.getName()); + } + } + return builder.build(); + } + + static class Serializer extends JsonSerializer { + @Override + public void serialize(PipelineOptions value, JsonGenerator jgen, SerializerProvider provider) + throws IOException, JsonProcessingException { + ProxyInvocationHandler handler = (ProxyInvocationHandler) Proxy.getInvocationHandler(value); + Map options = Maps.newHashMap(handler.jsonOptions); + options.putAll(handler.options); + removeIgnoredOptions(handler.knownInterfaces, options); + ensureSerializable(handler.knownInterfaces, options); + jgen.writeStartObject(); + jgen.writeFieldName("options"); + jgen.writeObject(options); + jgen.writeEndObject(); + } + + /** + * We remove all properties within the passed in options where there getter is annotated with + * {@link JsonIgnore @JsonIgnore} from the passed in options using the passed in interfaces. + */ + private void removeIgnoredOptions( + Set> interfaces, Map options) { + // Find all the method names which are annotated with JSON ignore. + Set jsonIgnoreMethodNames = FluentIterable.from( + PipelineOptionsFactory.getClosureOfMethodsOnInterfaces(interfaces)) + .filter(JsonIgnorePredicate.INSTANCE).transform(new Function() { + @Override + public String apply(Method input) { + return input.getName(); + } + }).toSet(); + + // Remove all options which have the same method name as the descriptor. + for (PropertyDescriptor descriptor + : PipelineOptionsFactory.getPropertyDescriptors(interfaces)) { + if (jsonIgnoreMethodNames.contains(descriptor.getReadMethod().getName())) { + options.remove(descriptor.getName()); + } + } + } + + /** + * We use an {@link ObjectMapper} to verify that the passed in options are serializable + * and deserializable. + */ + private void ensureSerializable(Set> interfaces, + Map options) throws IOException { + // Construct a map from property name to the return type of the getter. + Map propertyToReturnType = Maps.newHashMap(); + for (PropertyDescriptor descriptor + : PipelineOptionsFactory.getPropertyDescriptors(interfaces)) { + if (descriptor.getReadMethod() != null) { + propertyToReturnType.put(descriptor.getName(), + descriptor.getReadMethod().getGenericReturnType()); + } + } + + // Attempt to serialize and deserialize each property. + for (Map.Entry entry : options.entrySet()) { + String serializedValue = MAPPER.writeValueAsString(entry.getValue()); + JavaType type = MAPPER.getTypeFactory() + .constructType(propertyToReturnType.get(entry.getKey())); + MAPPER.readValue(serializedValue, type); + } + } + } + + static class Deserializer extends JsonDeserializer { + @Override + public PipelineOptions deserialize(JsonParser jp, DeserializationContext ctxt) + throws IOException, JsonProcessingException { + ObjectNode objectNode = (ObjectNode) jp.readValueAsTree(); + ObjectNode optionsNode = (ObjectNode) objectNode.get("options"); + + Map fields = Maps.newHashMap(); + for (Iterator> iterator = optionsNode.fields(); + iterator.hasNext(); ) { + Map.Entry field = iterator.next(); + fields.put(field.getKey(), field.getValue()); + } + PipelineOptions options = + new ProxyInvocationHandler(Maps.newHashMap(), fields) + .as(PipelineOptions.class); + return options; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/StreamingOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/StreamingOptions.java new file mode 100644 index 000000000000..725d845d5b9f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/StreamingOptions.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * [Whitelisting Required] Options used to configure the streaming backend. + * + *

Important: Streaming support is experimental. It is only supported in the + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner} for users whitelisted in a + * streaming early access program. + * + *

You should expect this class to change significantly in future + * versions of the SDK or be removed entirely. + */ +public interface StreamingOptions extends + ApplicationNameOptions, GcpOptions, PipelineOptions { + /** + * Note that this feature is currently experimental and only available to users whitelisted in + * a streaming early access program. + */ + @Description("True if running in streaming mode (experimental)") + boolean isStreaming(); + void setStreaming(boolean value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Validation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Validation.java new file mode 100644 index 000000000000..10f205fcadb9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Validation.java @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * {@link Validation} represents a set of annotations which can be used to annotate getter + * properties on {@link PipelineOptions} with information representing the validation criteria to + * be used when validating with the {@link PipelineOptionsValidator}. + */ + +public @interface Validation { + /** + * This criteria specifies that the value must be not null. Note that this annotation + * should only be applied to methods which return nullable objects. + */ + @Target(value = ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + public @interface Required { + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/package-info.java new file mode 100644 index 000000000000..557e377676b7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.options.PipelineOptions} for + * configuring pipeline execution. + * + *

{@link com.google.cloud.dataflow.sdk.options.PipelineOptions} encapsulates the various + * parameters that describe how a pipeline should be run. {@code PipelineOptions} are created + * using a {@link com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory}. + */ +package com.google.cloud.dataflow.sdk.options; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/package-info.java new file mode 100644 index 000000000000..e27ac0147660 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/package-info.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Provides a simple, powerful model for building both batch and + * streaming parallel data processing + * {@link com.google.cloud.dataflow.sdk.Pipeline}s. + * + *

To use the Google Cloud Dataflow SDK, you build a + * {@link com.google.cloud.dataflow.sdk.Pipeline} which manages a graph of + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s + * and the {@link com.google.cloud.dataflow.sdk.values.PCollection}s that + * the PTransforms consume and produce. + * + *

Each Pipeline has a + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} to specify + * where and how it should run after pipeline construction is complete. + * + */ +package com.google.cloud.dataflow.sdk; + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunner.java new file mode 100644 index 000000000000..61fb09746921 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunner.java @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.BlockingDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil.JobState; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +/** + * A PipelineRunner that's like {@link DataflowPipelineRunner} + * but that waits for the launched job to finish. + * + *

Prints out job status updates and console messages while it waits. + * + *

Returns the final job state, or throws an exception if the job + * fails or cannot be monitored. + */ +public class BlockingDataflowPipelineRunner extends + PipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(BlockingDataflowPipelineRunner.class); + + /** + * Holds the status of a run request. + */ + public static class PipelineJobState implements PipelineResult { + private final JobState state; + + public PipelineJobState(JobState state) { + this.state = state; + } + + public JobState getJobState() { + return state; + } + } + + // Defaults to an infinite wait period. + // TODO: make this configurable after removal of option map. + private static final long BUILTIN_JOB_TIMEOUT_SEC = -1L; + + private DataflowPipelineRunner dataflowPipelineRunner = null; + private MonitoringUtil.JobMessagesHandler jobMessagesHandler; + + protected BlockingDataflowPipelineRunner( + DataflowPipelineRunner internalRunner, + MonitoringUtil.JobMessagesHandler jobMessagesHandler) { + this.dataflowPipelineRunner = internalRunner; + this.jobMessagesHandler = jobMessagesHandler; + } + + /** + * Constructs a runner from the provided options. + */ + public static BlockingDataflowPipelineRunner fromOptions( + PipelineOptions options) { + BlockingDataflowPipelineOptions dataflowOptions = + PipelineOptionsValidator.validate(BlockingDataflowPipelineOptions.class, options); + DataflowPipelineRunner dataflowPipelineRunner = + DataflowPipelineRunner.fromOptions(dataflowOptions); + + return new BlockingDataflowPipelineRunner(dataflowPipelineRunner, + new MonitoringUtil.PrintHandler(dataflowOptions.getJobMessageOutput())); + } + + @Override + public PipelineJobState run(Pipeline p) { + DataflowPipelineJob job = dataflowPipelineRunner.run(p); + + @Nullable JobState result; + try { + result = job.waitToFinish( + BUILTIN_JOB_TIMEOUT_SEC, TimeUnit.SECONDS, jobMessagesHandler); + } catch (IOException | InterruptedException ex) { + throw new RuntimeException("Exception caught during job execution", ex); + } + + if (result == null) { + throw new RuntimeException("No result provided: " + + "possible error requesting job status."); + } + + LOG.info("Job finished with status {}", result); + if (result.isTerminal()) { + return new PipelineJobState(result); + } + + // TODO: introduce an exception which can wrap a JobState, + // so that detailed error information can be retrieved. + throw new RuntimeException("Job failed with state " + result); + } + + @Override + public Output apply( + PTransform transform, Input input) { + return dataflowPipelineRunner.apply(transform, input); + } + + /** + * Sets callbacks to invoke during execution see {@code DataflowPipelineRunnerHooks}. + * Important: setHooks is experimental. Please consult with the Dataflow team before using it. + * You should expect this class to change significantly in future versions of the SDK or be + * removed entirely. + */ + public void setHooks(DataflowPipelineRunnerHooks hooks) { + this.dataflowPipelineRunner.setHooks(hooks); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipeline.java new file mode 100644 index 000000000000..310b4d97a323 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipeline.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; + +/** + * A DataflowPipeline, which returns a + * {@link DataflowPipelineJob} subclass of PipelineResult + * from {@link com.google.cloud.dataflow.sdk.Pipeline#run()}. + */ +public class DataflowPipeline extends Pipeline { + + /** + * Creates and returns a new DataflowPipeline instance for tests. + */ + public static DataflowPipeline create(DataflowPipelineOptions options) { + return new DataflowPipeline(options); + } + + private DataflowPipeline(DataflowPipelineOptions options) { + super(DataflowPipelineRunner.fromOptions(options), options); + } + + @Override + public DataflowPipelineJob run() { + return (DataflowPipelineJob) super.run(); + } + + @Override + public DataflowPipelineRunner getRunner() { + return (DataflowPipelineRunner) super.getRunner(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJob.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJob.java new file mode 100644 index 000000000000..c1facb0288b8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJob.java @@ -0,0 +1,169 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil.JobState; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +/** + * A DataflowPipelineJob represents a job submitted to Dataflow using + * {@link DataflowPipelineRunner}. + */ +public class DataflowPipelineJob implements PipelineResult { + private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineJob.class); + + /** + * The id for the job. + */ + private String jobId; + + /** + * Google cloud project to associate this pipeline with. + */ + private String project; + + /** + * Client for the Dataflow service. This can be used to query the service + * for information about the job. + */ + private Dataflow dataflowClient; + + /** + * Construct the job. + * + * @param projectId the project id + * @param jobId the job id + * @param client the workflow client + */ + public DataflowPipelineJob( + String projectId, String jobId, Dataflow client) { + project = projectId; + this.jobId = jobId; + dataflowClient = client; + } + + public String getJobId() { + return jobId; + } + + public String getProjectId() { + return project; + } + + public Dataflow getDataflowClient() { + return dataflowClient; + } + + /** + * Wait for the job to finish and return the final status. + * + * @param timeToWait The time to wait in units timeUnit for the job to finish. + * @param timeUnit The unit of time for timeToWait. + * Provide a negative value for an infinite wait. + * @param messageHandler If non null this handler will be invoked for each + * batch of messages received. + * @return The final state of the job or null on timeout or if the + * thread is interrupted. + * @throws IOException If there is a persistent problem getting job + * information. + * @throws InterruptedException + */ + @Nullable + public JobState waitToFinish( + long timeToWait, + TimeUnit timeUnit, + MonitoringUtil.JobMessagesHandler messageHandler) + throws IOException, InterruptedException { + // The polling interval for job status information. + long interval = TimeUnit.SECONDS.toMillis(2); + + // The time at which to stop. + long endTime = timeToWait >= 0 + ? System.currentTimeMillis() + timeUnit.toMillis(timeToWait) + : Long.MAX_VALUE; + + MonitoringUtil monitor = new MonitoringUtil(project, dataflowClient); + + long lastTimestamp = 0; + int errorGettingMessages = 0; + int errorGettingJobStatus = 0; + while (true) { + if (System.currentTimeMillis() >= endTime) { + // Timed out. + return null; + } + + if (messageHandler != null) { + // Process all the job messages that have accumulated so far. + try { + List allMessages = monitor.getJobMessages( + jobId, lastTimestamp); + + if (!allMessages.isEmpty()) { + lastTimestamp = + fromCloudTime(allMessages.get(allMessages.size() - 1).getTime()).getMillis(); + messageHandler.process(allMessages); + } + } catch (GoogleJsonResponseException | SocketTimeoutException e) { + if (++errorGettingMessages > 5) { + // We want to continue to wait for the job to finish so + // we ignore this error, but warn occasionally if it keeps happening. + LOG.warn("There are problems accessing job messages: ", e); + errorGettingMessages = 0; + } + } + } + + // Check if the job is done. + try { + Job job = dataflowClient.v1b3().projects().jobs().get(project, jobId).execute(); + JobState state = JobState.toState(job.getCurrentState()); + if (state.isTerminal()) { + return state; + } + } catch (GoogleJsonResponseException | SocketTimeoutException e) { + if (++errorGettingJobStatus > 5) { + // We want to continue to wait for the job to finish so + // we ignore this error, but warn occasionally if it keeps happening. + LOG.warn("There were problems getting job status: ", e); + errorGettingJobStatus = 0; + } + } + + // Job not yet done. Wait a little, then check again. + long sleepTime = Math.min( + endTime - System.currentTimeMillis(), interval); + TimeUnit.MILLISECONDS.sleep(sleepTime); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java new file mode 100644 index 000000000000..ed01b8345c18 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java @@ -0,0 +1,315 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.util.Joiner; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Job; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.DataflowReleaseInfo; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.PackageUtil; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.PrintWriter; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A {@link PipelineRunner} that executes the operations in the + * pipeline by first translating them to the Dataflow representation + * using the {@link DataflowPipelineTranslator} and then submitting + * them to a Dataflow service for execution. + */ +public class DataflowPipelineRunner extends PipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineRunner.class); + + /** Provided configuration options. */ + private final DataflowPipelineOptions options; + + /** The directory on GCS where files should be uploaded. */ + private final GcsPath gcsStaging; + + /** The directory on GCS where temporary files are stored. */ + private final GcsPath gcsTemp; + + /** Client for the Dataflow service. This is used to actually submit jobs. */ + private final Dataflow dataflowClient; + + /** Translator for this DataflowPipelineRunner, based on options. */ + private final DataflowPipelineTranslator translator; + + /** A set of user defined functions to invoke at different points in execution. */ + private DataflowPipelineRunnerHooks hooks; + + // Environment version information + private static final String ENVIRONMENT_MAJOR_VERSION = "0"; + + /** + * Construct a runner from the provided options. + * + * @param options Properties which configure the runner. + * @return The newly created runner. + */ + public static DataflowPipelineRunner fromOptions(PipelineOptions options) { + DataflowPipelineOptions dataflowOptions = + PipelineOptionsValidator.validate(DataflowPipelineOptions.class, options); + ArrayList missing = new ArrayList<>(); + + if (dataflowOptions.getProject() == null) { + missing.add("project"); + } + if (dataflowOptions.getAppName() == null) { + missing.add("appName"); + } + if (missing.size() > 0) { + throw new IllegalArgumentException( + "Missing required values: " + Joiner.on(',').join(missing)); + } + + Preconditions.checkArgument(!(Strings.isNullOrEmpty(dataflowOptions.getTempLocation()) + && Strings.isNullOrEmpty(dataflowOptions.getStagingLocation())), + "Missing required value: at least one of tempLocation or stagingLocation must be set."); + if (Strings.isNullOrEmpty(dataflowOptions.getTempLocation())) { + dataflowOptions.setTempLocation(dataflowOptions.getStagingLocation()); + } else if (Strings.isNullOrEmpty(dataflowOptions.getStagingLocation())) { + dataflowOptions.setStagingLocation( + GcsPath.fromUri(dataflowOptions.getTempLocation()).resolve("staging").toString()); + } + + if (dataflowOptions.getFilesToStage() == null) { + dataflowOptions.setFilesToStage(detectClassPathResourcesToStage( + DataflowPipelineRunner.class.getClassLoader())); + LOG.info("No specified files to stage. Defaulting to files: {}", + dataflowOptions.getFilesToStage()); + } + + // Verify jobName according to service requirements. + String jobName = dataflowOptions.getJobName().toLowerCase(); + Preconditions.checkArgument( + jobName.matches("[a-z]([-a-z0-9]*[a-z0-9])?"), + "JobName invalid; the name must consist of only the characters " + + "[-a-z0-9], starting with a letter and ending with a letter " + + "or number"); + Preconditions.checkArgument(jobName.length() <= 40, + "JobName too long; must be no more than 40 characters in length"); + + return new DataflowPipelineRunner(dataflowOptions); + } + + private DataflowPipelineRunner(DataflowPipelineOptions options) { + this.options = options; + this.dataflowClient = options.getDataflowClient(); + this.gcsTemp = GcsPath.fromUri(options.getTempLocation()); + this.gcsStaging = GcsPath.fromUri(options.getStagingLocation()); + this.translator = DataflowPipelineTranslator.fromOptions(options); + + // (Re-)register standard IO factories. Clobbers any prior credentials. + IOChannelUtils.registerStandardIOFactories(options); + } + + @Override + @SuppressWarnings("unchecked") + public Output apply( + PTransform transform, Input input) { + if (transform instanceof Combine.GroupedValues) { + // TODO: Redundant with translator registration? + return (Output) PCollection.createPrimitiveOutputInternal( + ((PCollection) input).getWindowingFn()); + } else if (transform instanceof GroupByKey) { + // The DataflowPipelineRunner implementation of GroupByKey will sort values by timestamp, + // so no need for an explicit sort transform. + boolean runnerSortsByTimestamp = true; + return (Output) ((GroupByKey) transform).applyHelper( + (PCollection) input, options.isStreaming(), runnerSortsByTimestamp); + } else { + return super.apply(transform, input); + } + } + + @Override + public DataflowPipelineJob run(Pipeline pipeline) { + LOG.info("Executing pipeline on the Dataflow Service, which will have billing implications " + + "related to Google Compute Engine usage and other Google Cloud Services."); + + GcsUtil gcsUtil = options.getGcsUtil(); + List packages = + PackageUtil.stageClasspathElementsToGcs(gcsUtil, options.getFilesToStage(), gcsStaging); + + Job newJob = translator.translate(pipeline, packages); + + String version = DataflowReleaseInfo.getReleaseInfo().getVersion(); + System.out.println("Dataflow SDK version: " + version); + + newJob.getEnvironment().setUserAgent(DataflowReleaseInfo.getReleaseInfo()); + // The Dataflow Service may write to the temporary directory directly, so + // must be verified. + newJob.getEnvironment().setTempStoragePrefix(verifyGcsPath(gcsTemp).toResourceName()); + newJob.getEnvironment().setDataset(options.getTempDatasetId()); + newJob.getEnvironment().setClusterManagerApiService( + options.getClusterManagerApi().getApiServiceName()); + newJob.getEnvironment().setExperiments(options.getExperiments()); + + // Requirements about the service. + Map environmentVersion = new HashMap<>(); + // TODO: Specify the environment major version. + // environmentVersion.put(PropertyNames.ENVIRONMENT_VERSION_MAJOR_KEY, + // ENVIRONMENT_MAJOR_VERSION); + newJob.getEnvironment().setVersion(environmentVersion); + // Default jobType is DATA_PARALLEL which is for java batch. + String jobType = "DATA_PARALLEL"; + + if (options.isStreaming()) { + jobType = "STREAMING"; + } + environmentVersion.put(PropertyNames.ENVIRONMENT_VERSION_JOB_TYPE_KEY, jobType); + + if (hooks != null) { + hooks.modifyEnvironmentBeforeSubmission(newJob.getEnvironment()); + } + + if (!Strings.isNullOrEmpty(options.getDataflowJobFile())) { + try (PrintWriter printWriter = new PrintWriter( + new File(options.getDataflowJobFile()))) { + String workSpecJson = DataflowPipelineTranslator.jobToString(newJob); + printWriter.print(workSpecJson); + LOG.info("Printed workflow specification to {}", options.getDataflowJobFile()); + } catch (JsonProcessingException ex) { + LOG.warn("Cannot translate workflow spec to json for debug."); + } catch (FileNotFoundException ex) { + LOG.warn("Cannot create workflow spec output file."); + } + } + + Job jobResult; + try { + jobResult = dataflowClient.v1b3().projects().jobs() + .create(options.getProject(), newJob) + .execute(); + } catch (GoogleJsonResponseException e) { + throw new RuntimeException( + "Failed to create a workflow job: " + + (e.getDetails() != null ? e.getDetails().getMessage() : e), e); + } catch (IOException e) { + throw new RuntimeException("Failed to create a workflow job", e); + } + + LOG.info("To access the Dataflow monitoring console, please navigate to {}", + MonitoringUtil.getJobMonitoringPageURL(options.getProject(), jobResult.getId())); + System.out.println("Submitted job: " + jobResult.getId()); + + // Use a raw client for post-launch monitoring, as status calls may fail + // regularly and need not be retried automatically. + return new DataflowPipelineJob(options.getProject(), jobResult.getId(), + Transport.newRawDataflowClient(options).build()); + } + + /** + * Returns the DataflowPipelineTranslator associated with this object. + */ + public DataflowPipelineTranslator getTranslator() { + return translator; + } + + /** + * Sets callbacks to invoke during execution see {@code DataflowPipelineRunnerHooks}. + * Important: setHooks is experimental. Please consult with the Dataflow team before using it. + * You should expect this class to change significantly in future versions of the SDK or be + * removed entirely. + */ + public void setHooks(DataflowPipelineRunnerHooks hooks) { + this.hooks = hooks; + } + + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public String toString() { return "DataflowPipelineRunner#" + hashCode(); } + + /** + * Verifies that a path can be used by the Dataflow Service API. + * @return the supplied path + */ + public static GcsPath verifyGcsPath(GcsPath path) { + Preconditions.checkArgument(path.isAbsolute(), + "Must provide absolute paths for Dataflow"); + Preconditions.checkArgument(!path.getObject().contains("//"), + "Dataflow Service does not allow objects with consecutive slashes"); + return path; + } + + /** + * Attempts to detect all the resources the class loader has access to. This does not recurse + * to class loader parents stopping it from pulling in resources from the system class loader. + * + * @param classLoader The URLClassLoader to use to detect resources to stage. + * @throws IllegalArgumentException If either the class loader is not a URLClassLoader or one + * of the resources the class loader exposes is not a file resource. + * @return A list of absolute paths to the resources the class loader uses. + */ + protected static List detectClassPathResourcesToStage(ClassLoader classLoader) { + if (!(classLoader instanceof URLClassLoader)) { + String message = String.format("Unable to use ClassLoader to detect classpath elements. " + + "Current ClassLoader is %s, only URLClassLoaders are supported.", classLoader); + LOG.error(message); + throw new IllegalArgumentException(message); + } + + List files = new ArrayList<>(); + for (URL url : ((URLClassLoader) classLoader).getURLs()) { + try { + files.add(new File(url.toURI()).getAbsolutePath()); + } catch (IllegalArgumentException | URISyntaxException e) { + String message = String.format("Unable to convert url (%s) to file.", url); + LOG.error(message); + throw new IllegalArgumentException(message, e); + } + } + return files; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerHooks.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerHooks.java new file mode 100644 index 000000000000..ba822e876e48 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerHooks.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.api.services.dataflow.model.Environment; + +/** + * An instance of this class can be passed to the + * DataflowPipeline runner to add user defined hooks to be + * invoked at various times during pipeline execution. + * + * Important: DataflowPipelineRunnerHooks is experimental. Please consult with + * the Dataflow team before using it. You should expect this class to change significantly + * in future versions of the SDK or be removed entirely. + * + */ +public class DataflowPipelineRunnerHooks { + /** + * Allows the user to modify the environment of their job before their job is submitted + * to the service for execution. + * + * @param environment The environment of the job. Users can make change to this instance in order + * to change the environment with which their job executes on the service. + */ + public void modifyEnvironmentBeforeSubmission(Environment environment) {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java new file mode 100644 index 000000000000..6f39a2bae5b8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java @@ -0,0 +1,963 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; +import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; +import static com.google.cloud.dataflow.sdk.util.Structs.addDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.addList; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addObject; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.api.client.util.Preconditions; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Disk; +import com.google.api.services.dataflow.model.Environment; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.Step; +import com.google.api.services.dataflow.model.TaskRunnerSettings; +import com.google.api.services.dataflow.model.WorkerPool; +import com.google.api.services.dataflow.model.WorkerSettings; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.DatastoreIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.runners.dataflow.AvroIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.BigQueryIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.DatastoreIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.PubsubIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.TextIOTranslator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey.GroupByKeyOnly; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.OutputReference; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import com.google.common.base.Strings; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * DataflowPipelineTranslator knows how to translate Pipeline objects + * into Dataflow API Jobs. + */ +public class DataflowPipelineTranslator { + // Must be kept in sync with their internal counterparts. + public static final String HARNESS_WORKER_POOL = "harness"; + public static final String SHUFFLE_WORKER_POOL = "shuffle"; + private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineTranslator.class); + + /** + * A map from PTransform class to the corresponding + * TransformTranslator to use to translate that transform. + * + *

A static map that contains system-wide defaults. + */ + private static Map transformTranslators = + new HashMap<>(); + + /** Provided configuration options. */ + private final DataflowPipelineOptions options; + + /** + * Constructs a translator from the provided options. + * + * @param options Properties which configure the translator. + * + * @return The newly created translator. + */ + public static DataflowPipelineTranslator fromOptions( + DataflowPipelineOptions options) { + return new DataflowPipelineTranslator(options); + } + + private DataflowPipelineTranslator(DataflowPipelineOptions options) { + this.options = options; + } + + /** + * Translates a Pipeline into a Job + */ + public Job translate(Pipeline pipeline, List packages) { + Translator translator = new Translator(pipeline); + return translator.translate(packages); + } + + public static String jobToString(Job job) + throws JsonProcessingException { + return new ObjectMapper().writerWithDefaultPrettyPrinter() + .writeValueAsString(job); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Records that instances of the specified PTransform class + * should be translated by default by the corresponding + * TransformTranslator. + */ + public static void registerTransformTranslator( + Class transformClass, + TransformTranslator transformTranslator) { + if (transformTranslators.put(transformClass, transformTranslator) != null) { + throw new IllegalArgumentException( + "defining multiple translators for " + transformClass); + } + } + + /** + * Returns the TransformTranslator to use for instances of the + * specified PTransform class, or null if none registered. + */ + @SuppressWarnings("unchecked") + public + TransformTranslator getTransformTranslator(Class transformClass) { + return transformTranslators.get(transformClass); + } + + /** + * An translator of a PTransform. + */ + public interface TransformTranslator { + public void translate(PT transform, + TranslationContext context); + } + + /** + * The interface provided to registered callbacks for interacting + * with the DataflowPipelineRunner, including reading and writing the + * values of PCollections and side inputs ({@link PCollectionViews}). + */ + public interface TranslationContext { + /** + * Returns the configured pipeline options. + */ + DataflowPipelineOptions getPipelineOptions(); + + /** + * Adds a step to the Dataflow workflow for the given transform, with + * the given Dataflow step type. + * This step becomes "current" for the purpose of {@link #addInput} and + * {@link #addOutput}. + */ + public void addStep(PTransform transform, String type); + + /** + * Adds a pre-defined step to the Dataflow workflow. The given PTransform should be + * consistent with the Step, in terms of input, output and coder types. + * + *

This is a low-level operation, when using this method it is up to + * the caller to ensure that names do not collide. + */ + public void addStep(PTransform transform, Step step); + + /** + * Sets the encoding for the current Dataflow step. + */ + public void addEncodingInput(Coder value); + + /** + * Adds an input with the given name and value to the current + * Dataflow step. + */ + public void addInput(String name, String value); + + /** + * Adds an input with the given name and value to the current + * Dataflow step. + */ + public void addInput(String name, Long value); + + /** + * Adds an input with the given name to the previously added Dataflow + * step, coming from the specified input PValue. + */ + public void addInput(String name, PInput value); + + /** + * Adds an input with the given name and value to the current + * Dataflow step. + * + *

This applies any verification of paths required by the Dataflow + * service. + */ + public void addInput(String name, GcsPath path); + + /** + * Adds an input which is a dictionary of strings to objects. + */ + public void addInput(String name, Map elements); + + /** + * Adds an input which is a list of objects. + */ + public void addInput(String name, List> elements); + + /** + * Adds an output with the given name to the previously added + * Dataflow step, producing the specified output {@code PValue}, + * including its {@code Coder} if a {@code TypedPValue}. If the + * {@code PValue} is a {@code PCollection}, wraps its coder inside + * a {@code WindowedValueCoder}. + */ + public void addOutput(String name, PValue value); + + /** + * Adds an output with the given name to the previously added + * Dataflow step, producing the specified output {@code PValue}, + * including its {@code Coder} if a {@code TypedPValue}. If the + * {@code PValue} is a {@code PCollection}, wraps its coder inside + * a {@code ValueOnlyCoder}. + */ + public void addValueOnlyOutput(String name, PValue value); + + /** + * Adds an output with the given name to the previously added + * CollectionToSingleton Dataflow step, consuming the specified + * input {@code PValue} and producing the specified output + * {@code PValue}. This step requires special treatment for its + * output encoding. + */ + public void addCollectionToSingletonOutput(String name, + PValue inputValue, + PValue outputValue); + + /** + * Encode a PValue reference as an output reference. + */ + public OutputReference asOutputReference(PValue value); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Translates a Pipeline into the Dataflow representation. + */ + class Translator implements PipelineVisitor, TranslationContext { + /** The Pipeline to translate. */ + private final Pipeline pipeline; + + /** The Cloud Dataflow Job representation. */ + private final Job job = new Job(); + + /** + * Translator is stateful, as addProperty calls refer to the current step. + */ + private Step currentStep; + + /** + * A Map from PTransforms to their unique Dataflow step names. + */ + private final Map stepNames = new HashMap<>(); + + /** + * A Map from PValues to their output names used by their producer + * Dataflow steps. + */ + private final Map outputNames = new HashMap<>(); + + /** + * A Map from PValues to the Coders used for them. + */ + private final Map> outputCoders = new HashMap<>(); + + /** + * Constructs a Translator that will translate the specified + * Pipeline into Dataflow objects. + */ + public Translator(Pipeline pipeline) { + this.pipeline = pipeline; + } + + /** + * Translates this Translator's pipeline onto its writer. + * @return a Job definition filled in with the type of job, the environment, + * and the job steps. + */ + public Job translate(List packages) { + job.setName(options.getJobName().toLowerCase()); + + Environment environment = new Environment(); + job.setEnvironment(environment); + + WorkerPool workerPool = new WorkerPool(); + + workerPool.setKind(HARNESS_WORKER_POOL); + + // Pass the URL and endpoint to use to the worker pool. + WorkerSettings workerSettings = new WorkerSettings(); + workerSettings.setBaseUrl(options.getApiRootUrl()); + workerSettings.setServicePath(options.getDataflowEndpoint()); + + TaskRunnerSettings taskRunnerSettings = new TaskRunnerSettings(); + taskRunnerSettings.setParallelWorkerSettings(workerSettings); + + workerPool.setTaskrunnerSettings(taskRunnerSettings); + + WorkerPool shufflePool = new WorkerPool(); + shufflePool.setKind(SHUFFLE_WORKER_POOL); + + if (options.isStreaming()) { + job.setType("JOB_TYPE_STREAMING"); + } else { + job.setType("JOB_TYPE_BATCH"); + } + + if (options.getWorkerMachineType() != null) { + workerPool.setMachineType(options.getWorkerMachineType()); + } + + workerPool.setPackages(packages); + workerPool.setNumWorkers(options.getNumWorkers()); + shufflePool.setNumWorkers(options.getNumWorkers()); + if (options.getDiskSourceImage() != null) { + workerPool.setDiskSourceImage(options.getDiskSourceImage()); + shufflePool.setDiskSourceImage(options.getDiskSourceImage()); + } + + if (options.getMachineType() != null) { + workerPool.setMachineType(options.getMachineType()); + } + if (options.isStreaming()) { + // Use separate data disk for streaming. + Disk disk = new Disk(); + disk.setSizeGb(10); + disk.setDiskType( + // TODO: Fill in the project and zone. + "compute.googleapis.com/projects//zones//diskTypes/pd-standard"); + // TODO: introduce a separate location for Windmill binary in the + // TaskRunner so it wouldn't interfere with the data disk mount point. + disk.setMountPoint("/windmill"); + workerPool.setDataDisks(Collections.singletonList(disk)); + } + if (!Strings.isNullOrEmpty(options.getZone())) { + workerPool.setZone(options.getZone()); + shufflePool.setZone(options.getZone()); + } + if (options.getDiskSizeGb() > 0) { + workerPool.setDiskSizeGb(options.getDiskSizeGb()); + shufflePool.setDiskSizeGb(options.getDiskSizeGb()); + } + + // Set up any specific shuffle pool parameters + if (options.getShuffleNumWorkers() > 0) { + shufflePool.setNumWorkers(options.getShuffleNumWorkers()); + } + if (options.getShuffleDiskSourceImage() != null) { + shufflePool.setDiskSourceImage(options.getShuffleDiskSourceImage()); + } + if (!Strings.isNullOrEmpty(options.getShuffleZone())) { + shufflePool.setZone(options.getShuffleZone()); + } + if (options.getShuffleDiskSizeGb() > 0) { + shufflePool.setDiskSizeGb(options.getShuffleDiskSizeGb()); + } + + List workerPools = new LinkedList<>(); + + workerPools.add(workerPool); + if (!options.isStreaming()) { + workerPools.add(shufflePool); + } + environment.setWorkerPools(workerPools); + + pipeline.traverseTopologically(this); + return job; + } + + @Override + public DataflowPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + + @SuppressWarnings("unchecked") + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + TransformTranslator translator = + getTransformTranslator(transform.getClass()); + if (translator == null) { + throw new IllegalStateException( + "no translator registered for " + transform); + } + LOG.debug("Translating {}", transform); + translator.translate(transform, this); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + LOG.debug("Checking translation of {}", value); + if (options.isStreaming() + && value instanceof PCollectionView) { + throw new UnsupportedOperationException( + "PCollectionViews are not supported in streaming Dataflow."); + } + if (value.getProducingTransformInternal() == null) { + throw new RuntimeException( + "internal error: expecting a PValue " + + "to have a producingTransform"); + } + if (!producer.isCompositeNode()) { + // Primitive transforms are the only ones assigned step names. + asOutputReference(value); + } + } + + @Override + public void addStep(PTransform transform, String type) { + String stepName = genStepName(); + if (stepNames.put(transform, stepName) != null) { + throw new IllegalArgumentException( + transform + " already has a name specified"); + } + // Start the next "steps" list item. + List steps = job.getSteps(); + if (steps == null) { + steps = new LinkedList<>(); + job.setSteps(steps); + } + + currentStep = new Step(); + currentStep.setName(stepName); + currentStep.setKind(type); + steps.add(currentStep); + addInput(PropertyNames.USER_NAME, pipeline.getFullName(transform)); + } + + @Override + public void addStep(PTransform transform, Step original) { + Step step = original.clone(); + String stepName = step.getName(); + if (stepNames.put(transform, stepName) != null) { + throw new IllegalArgumentException(transform + " already has a name specified"); + } + + Map properties = step.getProperties(); + if (properties != null) { + @Nullable List> outputInfoList = null; + try { + // TODO: This should be done via a Structs accessor. + outputInfoList = (List>) properties.get(PropertyNames.OUTPUT_INFO); + } catch (Exception e) { + throw new RuntimeException("Inconsistent dataflow pipeline translation", e); + } + if (outputInfoList != null && outputInfoList.size() > 0) { + Map firstOutputPort = outputInfoList.get(0); + @Nullable String name; + try { + name = getString(firstOutputPort, PropertyNames.OUTPUT_NAME); + } catch (Exception e) { + name = null; + } + if (name != null) { + registerOutputName(pipeline.getOutput(transform), name); + } + } + } + + List steps = job.getSteps(); + if (steps == null) { + steps = new LinkedList<>(); + job.setSteps(steps); + } + currentStep = step; + steps.add(step); + } + + @Override + public void addEncodingInput(Coder coder) { + CloudObject encoding = SerializableUtils.ensureSerializable(coder); + addObject(getProperties(), PropertyNames.ENCODING, encoding); + } + + @Override + public void addInput(String name, String value) { + addString(getProperties(), name, value); + } + + @Override + public void addInput(String name, Long value) { + addLong(getProperties(), name, value); + } + + @Override + public void addInput(String name, Map elements) { + addDictionary(getProperties(), name, elements); + } + + @Override + public void addInput(String name, List> elements) { + addList(getProperties(), name, elements); + } + + @Override + public void addInput(String name, PInput value) { + if (value instanceof PValue) { + addInput(name, asOutputReference((PValue) value)); + } else { + throw new IllegalStateException("Input must be a PValue"); + } + } + + @Override + public void addInput(String name, GcsPath path) { + addInput(name, DataflowPipelineRunner.verifyGcsPath(path).toResourceName()); + } + + @Override + public void addOutput(String name, PValue value) { + Coder coder; + if (value instanceof TypedPValue) { + coder = ((TypedPValue) value).getCoder(); + if (value instanceof PCollection) { + // Wrap the PCollection element Coder inside a WindowedValueCoder. + coder = WindowedValue.getFullCoder( + coder, + ((PCollection) value).getWindowingFn().windowCoder()); + } + } else { + // No output coder to encode. + coder = null; + } + addOutput(name, value, coder); + } + + @Override + public void addValueOnlyOutput(String name, PValue value) { + Coder coder; + if (value instanceof TypedPValue) { + coder = ((TypedPValue) value).getCoder(); + if (value instanceof PCollection) { + // Wrap the PCollection element Coder inside a ValueOnly + // WindowedValueCoder. + coder = WindowedValue.getValueOnlyCoder(coder); + } + } else { + // No output coder to encode. + coder = null; + } + addOutput(name, value, coder); + } + + @Override + public void addCollectionToSingletonOutput(String name, + PValue inputValue, + PValue outputValue) { + Coder inputValueCoder = + Preconditions.checkNotNull(outputCoders.get(inputValue)); + // The inputValueCoder for the input PCollection should be some + // WindowedValueCoder of the input PCollection's element + // coder. + Preconditions.checkState( + inputValueCoder instanceof WindowedValue.WindowedValueCoder); + // The outputValueCoder for the output should be an + // IterableCoder of the inputValueCoder. This is a property + // of the backend "CollectionToSingleton" step. + Coder outputValueCoder = IterableCoder.of(inputValueCoder); + addOutput(name, outputValue, outputValueCoder); + } + + /** + * Adds an output with the given name to the previously added + * Dataflow step, producing the specified output {@code PValue} + * with the given {@code Coder} (if not {@code null}). + */ + @SuppressWarnings("unchecked") + private void addOutput(String name, PValue value, Coder valueCoder) { + registerOutputName(value, name); + + Map properties = getProperties(); + @Nullable List> outputInfoList = null; + try { + // TODO: This should be done via a Structs accessor. + outputInfoList = (List>) properties.get(PropertyNames.OUTPUT_INFO); + } catch (Exception e) { + throw new RuntimeException("Inconsistent dataflow pipeline translation", e); + } + if (outputInfoList == null) { + outputInfoList = new ArrayList<>(); + // TODO: This should be done via a Structs accessor. + properties.put(PropertyNames.OUTPUT_INFO, outputInfoList); + } + + Map outputInfo = new HashMap<>(); + addString(outputInfo, PropertyNames.OUTPUT_NAME, name); + addString(outputInfo, PropertyNames.USER_NAME, value.getName()); + + if (valueCoder != null) { + // Verify that encoding can be decoded, in order to catch serialization + // failures as early as possible. + CloudObject encoding = SerializableUtils.ensureSerializable(valueCoder); + addObject(outputInfo, PropertyNames.ENCODING, encoding); + outputCoders.put(value, valueCoder); + } + + outputInfoList.add(outputInfo); + } + + @Override + public OutputReference asOutputReference(PValue value) { + PTransform transform = + value.getProducingTransformInternal(); + String stepName = stepNames.get(transform); + if (stepName == null) { + throw new IllegalArgumentException(transform + " doesn't have a name specified"); + } + + String outputName = outputNames.get(value); + if (outputName == null) { + throw new IllegalArgumentException( + "output " + value + " doesn't have a name specified"); + } + + return new OutputReference(stepName, outputName); + } + + private Map getProperties() { + Map properties = currentStep.getProperties(); + if (properties == null) { + properties = new HashMap<>(); + currentStep.setProperties(properties); + } + return properties; + } + + /** + * Returns a fresh Dataflow step name. + */ + private String genStepName() { + return "s" + (stepNames.size() + 1); + } + + /** + * Records the name of the given output PValue, + * within its producing transform. + */ + private void registerOutputName(POutput value, String name) { + if (outputNames.put(value, name) != null) { + throw new IllegalArgumentException( + "output " + value + " already has a name specified"); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public String toString() { + return "DataflowPipelineTranslator#" + hashCode(); + } + + + /////////////////////////////////////////////////////////////////////////// + + static { + registerTransformTranslator( + View.CreatePCollectionView.class, + new TransformTranslator() { + @Override + public void translate( + View.CreatePCollectionView transform, + TranslationContext context) { + translateTyped(transform, context); + } + + private void translateTyped( + View.CreatePCollectionView transform, + TranslationContext context) { + context.addStep(transform, "CollectionToSingleton"); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + context.addCollectionToSingletonOutput( + PropertyNames.OUTPUT, + transform.getInput(), + transform.getOutput()); + } + }); + + DataflowPipelineTranslator.registerTransformTranslator( + Combine.GroupedValues.class, + new DataflowPipelineTranslator.TransformTranslator() { + @SuppressWarnings("unchecked") + @Override + public void translate( + Combine.GroupedValues transform, + DataflowPipelineTranslator.TranslationContext context) { + translateHelper(transform, context); + } + + private void translateHelper( + final Combine.GroupedValues transform, + DataflowPipelineTranslator.TranslationContext context) { + context.addStep(transform, "CombineValues"); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + context.addInput( + PropertyNames.SERIALIZED_FN, + byteArrayToJsonString(serializeToByteArray(transform.getFn()))); + context.addEncodingInput(transform.getAccumulatorCoder()); + context.addOutput(PropertyNames.OUTPUT, transform.getOutput()); + } + }); + + registerTransformTranslator( + Create.class, + new TransformTranslator() { + @Override + public void translate( + Create transform, + TranslationContext context) { + createHelper(transform, context); + } + + private void createHelper( + Create transform, + TranslationContext context) { + context.addStep(transform, "CreateCollection"); + + Coder coder = transform.getOutput().getCoder(); + List elements = new LinkedList<>(); + for (T elem : transform.getElements()) { + byte[] encodedBytes; + try { + encodedBytes = encodeToByteArray(coder, elem); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. + throw new IllegalArgumentException( + "unable to encode element " + elem + " of " + transform + + " using " + coder, + exn); + } + String encodedJson = byteArrayToJsonString(encodedBytes); + assert Arrays.equals(encodedBytes, + jsonStringToByteArray(encodedJson)); + elements.add(CloudObject.forString(encodedJson)); + } + context.addInput(PropertyNames.ELEMENT, elements); + context.addValueOnlyOutput(PropertyNames.OUTPUT, transform.getOutput()); + } + }); + + registerTransformTranslator( + Flatten.FlattenPCollectionList.class, + new TransformTranslator() { + @Override + public void translate( + Flatten.FlattenPCollectionList transform, + TranslationContext context) { + flattenHelper(transform, context); + } + + private void flattenHelper( + Flatten.FlattenPCollectionList transform, + TranslationContext context) { + context.addStep(transform, "Flatten"); + + List inputs = new LinkedList<>(); + for (PCollection input : transform.getInput().getAll()) { + inputs.add(context.asOutputReference(input)); + } + context.addInput(PropertyNames.INPUTS, inputs); + context.addOutput(PropertyNames.OUTPUT, transform.getOutput()); + // TODO: Need to specify orderedness. + } + }); + + registerTransformTranslator( + GroupByKeyOnly.class, + new TransformTranslator() { + @Override + public void translate( + GroupByKeyOnly transform, + TranslationContext context) { + groupByKeyHelper(transform, context); + } + + private void groupByKeyHelper( + GroupByKeyOnly transform, + TranslationContext context) { + context.addStep(transform, "GroupByKey"); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + context.addOutput(PropertyNames.OUTPUT, transform.getOutput()); + // TODO: sortsValues + } + }); + + registerTransformTranslator( + ParDo.BoundMulti.class, + new TransformTranslator() { + @Override + public void translate( + ParDo.BoundMulti transform, + TranslationContext context) { + translateMultiHelper(transform, context); + } + + private void translateMultiHelper( + ParDo.BoundMulti transform, + TranslationContext context) { + context.addStep(transform, "ParallelDo"); + translateInputs(transform.getInput(), transform.getSideInputs(), context); + translateFn(transform.getFn(), context); + translateOutputs(transform.getOutput(), context); + } + }); + + registerTransformTranslator( + ParDo.Bound.class, + new TransformTranslator() { + @Override + public void translate( + ParDo.Bound transform, + TranslationContext context) { + translateSingleHelper(transform, context); + } + + private void translateSingleHelper( + ParDo.Bound transform, + TranslationContext context) { + context.addStep(transform, "ParallelDo"); + translateInputs(transform.getInput(), transform.getSideInputs(), context); + translateFn(transform.getFn(), context); + context.addOutput("out", transform.getOutput()); + } + }); + + /////////////////////////////////////////////////////////////////////////// + // IO Translation. + + registerTransformTranslator( + AvroIO.Read.Bound.class, new AvroIOTranslator.ReadTranslator()); + registerTransformTranslator( + AvroIO.Write.Bound.class, new AvroIOTranslator.WriteTranslator()); + + registerTransformTranslator( + BigQueryIO.Read.Bound.class, new BigQueryIOTranslator.ReadTranslator()); + registerTransformTranslator( + BigQueryIO.Write.Bound.class, new BigQueryIOTranslator.WriteTranslator()); + + registerTransformTranslator( + DatastoreIO.Write.Bound.class, new DatastoreIOTranslator.WriteTranslator()); + + registerTransformTranslator( + PubsubIO.Read.Bound.class, new PubsubIOTranslator.ReadTranslator()); + registerTransformTranslator( + PubsubIO.Write.Bound.class, new PubsubIOTranslator.WriteTranslator()); + + registerTransformTranslator( + TextIO.Read.Bound.class, new TextIOTranslator.ReadTranslator()); + registerTransformTranslator( + TextIO.Write.Bound.class, new TextIOTranslator.WriteTranslator()); + } + + private static void translateInputs( + PCollection input, + List> sideInputs, + TranslationContext context) { + context.addInput(PropertyNames.PARALLEL_INPUT, input); + translateSideInputs(sideInputs, context); + } + + // Used for ParDo + private static void translateSideInputs( + List> sideInputs, + TranslationContext context) { + Map nonParInputs = new HashMap<>(); + + for (PCollectionView view : sideInputs) { + nonParInputs.put( + view.getTagInternal().getId(), + context.asOutputReference(view)); + } + + context.addInput(PropertyNames.NON_PARALLEL_INPUTS, nonParInputs); + } + + private static void translateFn( + Serializable fn, + TranslationContext context) { + context.addInput(PropertyNames.USER_FN, fn.getClass().getName()); + context.addInput( + PropertyNames.SERIALIZED_FN, + byteArrayToJsonString(serializeToByteArray(fn))); + if (fn instanceof DoFn.RequiresKeyedState) { + context.addInput(PropertyNames.USES_KEYED_STATE, "true"); + } + } + + private static void translateOutputs( + PCollectionTuple outputs, + TranslationContext context) { + for (Map.Entry, PCollection> entry + : outputs.getAll().entrySet()) { + TupleTag tag = entry.getKey(); + PCollection output = entry.getValue(); + context.addOutput(tag.getId(), output); + // TODO: Need to specify orderedness. + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipeline.java new file mode 100644 index 000000000000..e3cd18ecfda3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipeline.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; + +/** + * A DirectPipeline, which returns a + * {@link DirectPipelineRunner.EvaluationResults} subclass of PipelineResult + * from {@link com.google.cloud.dataflow.sdk.Pipeline#run()}. + */ +public class DirectPipeline extends Pipeline { + + /** + * Creates and returns a new DirectPipeline instance for tests. + */ + public static DirectPipeline createForTest() { + DirectPipelineRunner runner = DirectPipelineRunner.createForTest(); + return new DirectPipeline(runner, runner.getPipelineOptions()); + } + + private DirectPipeline(DirectPipelineRunner runner, DirectPipelineOptions options) { + super(runner, options); + } + + @Override + public DirectPipelineRunner.EvaluationResults run() { + return (DirectPipelineRunner.EvaluationResults) super.run(); + } + + @Override + public DirectPipelineRunner getRunner() { + return (DirectPipelineRunner) super.getRunner(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java new file mode 100644 index 000000000000..a19b2055a0b9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java @@ -0,0 +1,844 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import com.google.common.base.Function; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * Executes the operations in the pipeline directly, in this process, without + * any optimization. Useful for small local execution and tests. + * + *

Throws an exception from {@link #run} if execution fails. + */ +public class DirectPipelineRunner + extends PipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(DirectPipelineRunner.class); + + /** + * A map from PTransform class to the corresponding + * TransformEvaluator to use to evaluate that transform. + * + *

A static map that contains system-wide defaults. + */ + private static Map defaultTransformEvaluators = + new HashMap<>(); + + /** + * A map from PTransform class to the corresponding + * TransformEvaluator to use to evaluate that transform. + * + *

An instance map that contains bindings for this DirectPipelineRunner. + * Bindings in this map override those in the default map. + */ + private Map localTransformEvaluators = + new HashMap<>(); + + /** + * Records that instances of the specified PTransform class + * should be evaluated by default by the corresponding + * TransformEvaluator. + */ + public static > + void registerDefaultTransformEvaluator( + Class transformClass, + TransformEvaluator transformEvaluator) { + if (defaultTransformEvaluators.put(transformClass, transformEvaluator) + != null) { + throw new IllegalArgumentException( + "defining multiple evaluators for " + transformClass); + } + } + + /** + * Records that instances of the specified PTransform class + * should be evaluated by the corresponding TransformEvaluator. + * Overrides any bindings specified by + * {@link #registerDefaultTransformEvaluator}. + */ + public > + void registerTransformEvaluator( + Class transformClass, + TransformEvaluator transformEvaluator) { + if (localTransformEvaluators.put(transformClass, transformEvaluator) + != null) { + throw new IllegalArgumentException( + "defining multiple evaluators for " + transformClass); + } + } + + /** + * Returns the TransformEvaluator to use for instances of the + * specified PTransform class, or null if none registered. + */ + @SuppressWarnings("unchecked") + public > + TransformEvaluator getTransformEvaluator(Class transformClass) { + TransformEvaluator transformEvaluator = + localTransformEvaluators.get(transformClass); + if (transformEvaluator == null) { + transformEvaluator = defaultTransformEvaluators.get(transformClass); + } + return transformEvaluator; + } + + /** + * Constructs a DirectPipelineRunner from the given options. + */ + public static DirectPipelineRunner fromOptions(PipelineOptions options) { + DirectPipelineOptions directOptions = + PipelineOptionsValidator.validate(DirectPipelineOptions.class, options); + LOG.debug("Creating DirectPipelineRunner"); + return new DirectPipelineRunner(directOptions); + } + + /** + * Constructs a runner with default properties for testing. + * + * @return The newly created runner. + */ + public static DirectPipelineRunner createForTest() { + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + return new DirectPipelineRunner(options); + } + + /** + * Enable runtime testing to verify that all functions and {@link Coder} + * instances can be serialized. + * + *

Enabled by default. + * + *

This method modifies the {@code DirectPipelineRunner} instance and + * returns itself. + */ + public DirectPipelineRunner withSerializabilityTesting(boolean enable) { + this.testSerializability = enable; + return this; + } + + /** + * Enable runtime testing to verify that all values can be encoded. + * + *

Enabled by default. + * + *

This method modifies the {@code DirectPipelineRunner} instance and + * returns itself. + */ + public DirectPipelineRunner withEncodabilityTesting(boolean enable) { + this.testEncodability = enable; + return this; + } + + /** + * Enable runtime testing to verify that functions do not depend on order + * of the elements. + * + *

This is accomplished by randomizing the order of elements. + * + *

Enabled by default. + * + *

This method modifies the {@code DirectPipelineRunner} instance and + * returns itself. + */ + public DirectPipelineRunner withUnorderednessTesting(boolean enable) { + this.testUnorderedness = enable; + return this; + } + + @Override + @SuppressWarnings("unchecked") + public Output apply( + PTransform transform, Input input) { + if (transform instanceof Combine.GroupedValues) { + return (Output) applyTestCombine((Combine.GroupedValues) transform, (PCollection) input); + } else { + return super.apply(transform, input); + } + } + + private PCollection> applyTestCombine( + Combine.GroupedValues transform, + PCollection>> input) { + return input.apply(ParDo.of(TestCombineDoFn.create(transform, testSerializability))) + .setCoder(transform.getDefaultOutputCoder()); + } + + /** + * The implementation may split the KeyedCombineFn into ADD, MERGE + * and EXTRACT phases (see CombineValuesFn). In order to emulate + * this for the DirectPipelineRunner and provide an experience + * closer to the service, go through heavy seralizability checks for + * the equivalent of the results of the ADD phase, but after the + * GroupByKey shuffle, and the MERGE phase. Doing these checks + * ensure that not only is the accumulator coder serializable, but + * the accumulator coder can actually serialize the data in + * question. + */ + // @VisibleForTesting + public static class TestCombineDoFn + extends DoFn>, KV> { + private final KeyedCombineFn fn; + private final Coder accumCoder; + private final boolean testSerializability; + + @SuppressWarnings({"unchecked", "rawtypes"}) + public static TestCombineDoFn create( + Combine.GroupedValues transform, + boolean testSerializability) { + return new TestCombineDoFn( + transform.getFn(), transform.getAccumulatorCoder(), testSerializability); + } + + public TestCombineDoFn( + KeyedCombineFn fn, + Coder accumCoder, + boolean testSerializability) { + this.fn = fn; + this.accumCoder = accumCoder; + this.testSerializability = testSerializability; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + K key = c.element().getKey(); + Iterable values = c.element().getValue(); + List groupedPostShuffle = + ensureSerializableByCoder(ListCoder.of(accumCoder), + addInputsRandomly(fn, key, values, new Random()), + "After addInputs of KeyedCombineFn " + fn.toString()); + VA merged = + ensureSerializableByCoder(accumCoder, + fn.mergeAccumulators(key, groupedPostShuffle), + "After mergeAccumulators of KeyedCombineFn " + fn.toString()); + // Note: The serializability of KV is ensured by the + // runner itself, since it's a transform output. + c.output(KV.of(key, fn.extractOutput(key, merged))); + } + + // Create a random list of accumulators from the given list of values + // @VisibleForTesting + public static List addInputsRandomly( + KeyedCombineFn fn, + K key, + Iterable values, + Random random) { + List out = new ArrayList(); + int i = 0; + VA accumulator = fn.createAccumulator(key); + boolean hasInput = false; + + for (VI value : values) { + fn.addInput(key, accumulator, value); + hasInput = true; + + // For each index i, flip a 1/2^i weighted coin for whether to + // create a new accumulator after index i is added, i.e. [0] + // is guaranteed, [1] is an even 1/2, [2] is 1/4, etc. The + // goal is to partition the inputs into accumulators, and make + // the accumulators potentially lumpy. + if (i == 0 || random.nextInt(1 << Math.min(i, 30)) == 0) { + out.add(accumulator); + accumulator = fn.createAccumulator(key); + hasInput = false; + } + i++; + } + if (hasInput) { + out.add(accumulator); + } + + Collections.shuffle(out, random); + return out; + } + + public T ensureSerializableByCoder( + Coder coder, T value, String errorContext) { + if (testSerializability) { + return SerializableUtils.ensureSerializableByCoder( + coder, value, errorContext); + } + return value; + } + } + + @Override + public EvaluationResults run(Pipeline pipeline) { + Evaluator evaluator = new Evaluator(); + evaluator.run(pipeline); + + // Log all counter values for debugging purposes. + for (Counter counter : evaluator.getCounters()) { + LOG.debug("Final aggregator value: {}", counter); + } + + return evaluator; + } + + /** + * An evaluator of a PTransform. + */ + public interface TransformEvaluator { + public void evaluate(PT transform, + EvaluationContext context); + } + + /** + * The interface provided to registered callbacks for interacting + * with the {@code DirectPipelineRunner}, including reading and writing the + * values of {@link PCollection}s and {@link PCollectionView}s. + */ + public interface EvaluationResults extends PipelineResult { + /** + * Retrieves the value of the given PCollection. + * Throws an exception if the PCollection's value hasn't already been set. + */ + List getPCollection(PCollection pc); + + /** + * Retrieves the windowed value of the given PCollection. + * Throws an exception if the PCollection's value hasn't already been set. + */ + List> getPCollectionWindowedValues(PCollection pc); + + /** + * Retrieves the values of each PCollection in the given + * PCollectionList. Throws an exception if the PCollectionList's + * value hasn't already been set. + */ + List> getPCollectionList(PCollectionList pcs); + + /** + * Retrieves the values indicated by the given {@link PCollectionView}. + * Note that within the {@link DoFnContext} a {@link PCollectionView} + * converts from this representation to a suitable side input value. + */ + Iterable> getPCollectionView(PCollectionView view); + } + + /** + * An immutable (value, timestamp) pair, along with other metadata necessary + * for the implementation of {@code DirectPipelineRunner}. + */ + public static class ValueWithMetadata { + /** + * Returns a new {@code ValueWithMetadata} with the {@code WindowedValue}. + * Key is null. + */ + public static ValueWithMetadata of(WindowedValue windowedValue) { + return new ValueWithMetadata<>(windowedValue, null); + } + + /** + * Returns a new {@code ValueWithMetadata} with the implicit key associated + * with this value set. The key is the last key grouped by in the chain of + * productions that produced this element. + * These keys are used internally by {@link DirectPipelineRunner} for keeping + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.KeyedState} separate + * across keys. + */ + public ValueWithMetadata withKey(Object key) { + return new ValueWithMetadata<>(windowedValue, key); + } + + /** + * Returns a new {@code ValueWithMetadata} that is a copy of this one, but with + * a different value. + */ + public ValueWithMetadata withValue(T value) { + return new ValueWithMetadata(windowedValue.withValue(value), getKey()); + } + + /** + * Returns the {@code WindowedValue} associated with this element. + */ + public WindowedValue getWindowedValue() { + return windowedValue; + } + + /** + * Returns the value associated with this element. + * + * @see #withValue + */ + public V getValue() { + return windowedValue.getValue(); + } + + /** + * Returns the timestamp associated with this element. + */ + public Instant getTimestamp() { + return windowedValue.getTimestamp(); + } + + /** + * Returns the collection of windows this element has been placed into. May + * be null if the {@code PCollection} this element is in has not yet been + * windowed. + * + * @see #getWindows() + */ + public Collection getWindows() { + return windowedValue.getWindows(); + } + + + /** + * Returns the key associated with this element. May be null if the + * {@code PCollection} this element is in is not keyed. + * + * @see #withKey + */ + public Object getKey() { + return key; + } + + //////////////////////////////////////////////////////////////////////////// + + private final Object key; + private final WindowedValue windowedValue; + + private ValueWithMetadata(WindowedValue windowedValue, + Object key) { + this.windowedValue = windowedValue; + this.key = key; + } + } + + /** + * The interface provided to registered callbacks for interacting + * with the {@code DirectPipelineRunner}, including reading and writing the + * values of {@link PCollection}s and {@link PCollectionView}s. + */ + public interface EvaluationContext extends EvaluationResults { + /** + * Returns the configured pipeline options. + */ + DirectPipelineOptions getPipelineOptions(); + + /** + * Sets the value of the given PCollection, where each element also has a timestamp + * and collection of windows. + * Throws an exception if the PCollection's value has already been set. + */ + void setPCollectionValuesWithMetadata( + PCollection pc, List> elements); + + /** + * Shorthand for setting the value of a PCollection where the elements do not have + * timestamps or windows. + * Throws an exception if the PCollection's value has already been set. + */ + void setPCollection(PCollection pc, List elements); + + /** + * Retrieves the value of the given PCollection, along with element metadata + * such as timestamps and windows. + * Throws an exception if the PCollection's value hasn't already been set. + */ + List> getPCollectionValuesWithMetadata(PCollection pc); + + /** + * Sets the value associated with the given {@link PCollectionView}. + * Throws an exception if the {@link PCollectionView}'s value has already been set. + */ + void setPCollectionView( + PCollectionView pc, + Iterable> value); + + /** + * Ensures that the element is encodable and decodable using the + * TypePValue's coder, by encoding it and decoding it, and + * returning the result. + */ + T ensureElementEncodable(TypedPValue pvalue, T element); + + /** + * If the evaluation context is testing unorderedness and + * !isOrdered, randomly permutes the order of the elements, in a + * copy if !inPlaceAllowed, and returns the permuted list, + * otherwise returns the argument unchanged. + */ + List randomizeIfUnordered(boolean isOrdered, + List elements, + boolean inPlaceAllowed); + + /** + * If the evaluation context is testing serializability, ensures + * that the argument function is serializable and deserializable + * by encoding it and then decoding it, and returning the result. + * Otherwise returns the argument unchanged. + */ + Fn ensureSerializable(Fn fn); + + /** + * If the evaluation context is testing serializability, ensures + * that the argument Coder is serializable and deserializable + * by encoding it and then decoding it, and returning the result. + * Otherwise returns the argument unchanged. + */ + Coder ensureCoderSerializable(Coder coder); + + /** + * If the evaluation context is testing serializability, ensures + * that the given data is serializable and deserializable with the + * given Coder by encoding it and then decoding it, and returning + * the result. Otherwise returns the argument unchanged. + * + *

Error context is prefixed to any thrown exceptions. + */ + T ensureSerializableByCoder(Coder coder, + T data, String errorContext); + + /** + * Returns a mutator, which can be used to add additional counters to + * this EvaluationContext. + */ + CounterSet.AddCounterMutator getAddCounterMutator(); + + /** + * Gets the step name for this transform. + */ + public String getStepName(PTransform transform); + } + + + ///////////////////////////////////////////////////////////////////////////// + + class Evaluator implements PipelineVisitor, EvaluationContext { + private final Map stepNames = new HashMap<>(); + private final Map store = new HashMap<>(); + private final CounterSet counters = new CounterSet(); + + // Use a random number generator with a fixed seed, so execution + // using this evaluator is deterministic. (If the user-defined + // functions, transforms, and coders are deterministic.) + Random rand = new Random(0); + + public Evaluator() {} + + public void run(Pipeline pipeline) { + pipeline.traverseTopologically(this); + } + + @Override + public DirectPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + + @SuppressWarnings("unchecked") + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + TransformEvaluator evaluator = + getTransformEvaluator(transform.getClass()); + if (evaluator == null) { + throw new IllegalStateException( + "no evaluator registered for " + transform); + } + LOG.debug("Evaluating {}", transform); + evaluator.evaluate(transform, this); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + LOG.debug("Checking evaluation of {}", value); + if (value.getProducingTransformInternal() == null) { + throw new RuntimeException( + "internal error: expecting a PValue " + + "to have a producingTransform"); + } + if (!producer.isCompositeNode()) { + // Verify that primitive transform outputs are already computed. + getPValue(value); + } + } + + /** + * Sets the value of the given PValue. + * Throws an exception if the PValue's value has already been set. + */ + void setPValue(PValue pvalue, Object contents) { + if (store.containsKey(pvalue)) { + throw new IllegalStateException( + "internal error: setting the value of " + pvalue + + " more than once"); + } + store.put(pvalue, contents); + } + + /** + * Retrieves the value of the given PValue. + * Throws an exception if the PValue's value hasn't already been set. + */ + Object getPValue(PValue pvalue) { + if (!store.containsKey(pvalue)) { + throw new IllegalStateException( + "internal error: getting the value of " + pvalue + + " before it has been computed"); + } + return store.get(pvalue); + } + + /** + * Convert a list of T to a list of {@code ValueWithMetadata}, with a timestamp of 0 + * and null windows. + */ + List> toValuesWithMetadata(List values) { + List> result = new ArrayList<>(values.size()); + for (T value : values) { + result.add(ValueWithMetadata.of(WindowedValue.valueInGlobalWindow(value))); + } + return result; + } + + @Override + public void setPCollection(PCollection pc, List elements) { + setPCollectionValuesWithMetadata(pc, toValuesWithMetadata(elements)); + } + + @Override + public void setPCollectionValuesWithMetadata( + PCollection pc, List> elements) { + LOG.debug("Setting {} = {}", pc, elements); + setPValue(pc, ensurePCollectionEncodable(pc, elements)); + } + + @Override + public void setPCollectionView( + PCollectionView view, + Iterable> value) { + LOG.debug("Setting {} = {}", view, value); + setPValue(view, value); + } + + /** + * Retrieves the value of the given PCollection. + * Throws an exception if the PCollection's value hasn't already been set. + */ + @Override + public List getPCollection(PCollection pc) { + List result = new ArrayList<>(); + for (ValueWithMetadata elem : getPCollectionValuesWithMetadata(pc)) { + result.add(elem.getValue()); + } + return result; + } + + @Override + public List> getPCollectionWindowedValues(PCollection pc) { + return Lists.transform( + getPCollectionValuesWithMetadata(pc), + new Function, WindowedValue>() { + @Override + public WindowedValue apply(ValueWithMetadata input) { + return input.getWindowedValue(); + }}); + } + + @Override + public List> getPCollectionValuesWithMetadata(PCollection pc) { + @SuppressWarnings("unchecked") + List> elements = (List>) getPValue(pc); + elements = randomizeIfUnordered( + pc.isOrdered(), elements, false /* not inPlaceAllowed */); + LOG.debug("Getting {} = {}", pc, elements); + return elements; + } + + @Override + public List> getPCollectionList(PCollectionList pcs) { + List> elementsList = new ArrayList<>(); + for (PCollection pc : pcs.getAll()) { + elementsList.add(getPCollection(pc)); + } + return elementsList; + } + + /** + * Retrieves the value indicated by the given {@link PCollectionView}. + * Note that within the {@link DoFnContext} a {@link PCollectionView} + * converts from this representation to a suitable side input value. + */ + @Override + public Iterable> getPCollectionView(PCollectionView view) { + @SuppressWarnings("unchecked") + Iterable> value = (Iterable>) getPValue(view); + LOG.debug("Getting {} = {}", view, value); + return value; + } + + /** + * If testEncodability, ensures that the PCollection's coder and elements + * are encodable and decodable by encoding them and decoding them, + * and returning the result. Otherwise returns the argument elements. + */ + List> ensurePCollectionEncodable( + PCollection pc, List> elements) { + ensureCoderSerializable(pc.getCoder()); + if (!testEncodability) { + return elements; + } + List> elementsCopy = new ArrayList<>(elements.size()); + for (ValueWithMetadata element : elements) { + elementsCopy.add( + element.withValue(ensureElementEncodable(pc, element.getValue()))); + } + return elementsCopy; + } + + @Override + public T ensureElementEncodable(TypedPValue pvalue, T element) { + return ensureSerializableByCoder( + pvalue.getCoder(), element, "Within " + pvalue.toString()); + } + + @Override + public List randomizeIfUnordered(boolean isOrdered, + List elements, + boolean inPlaceAllowed) { + if (!testUnorderedness || isOrdered) { + return elements; + } + List elementsCopy = new ArrayList<>(elements); + Collections.shuffle(elementsCopy, rand); + return elementsCopy; + } + + @Override + public Fn ensureSerializable(Fn fn) { + if (!testSerializability) { + return fn; + } + return SerializableUtils.ensureSerializable(fn); + } + + @Override + public Coder ensureCoderSerializable(Coder coder) { + if (testSerializability) { + SerializableUtils.ensureSerializable(coder); + } + return coder; + } + + @Override + public T ensureSerializableByCoder( + Coder coder, T value, String errorContext) { + if (testSerializability) { + return SerializableUtils.ensureSerializableByCoder( + coder, value, errorContext); + } + return value; + } + + @Override + public CounterSet.AddCounterMutator getAddCounterMutator() { + return counters.getAddCounterMutator(); + } + + @Override + public String getStepName(PTransform transform) { + String stepName = stepNames.get(transform); + if (stepName == null) { + stepName = "s" + (stepNames.size() + 1); + stepNames.put(transform, stepName); + } + return stepName; + } + + /** + * Returns the CounterSet generated during evaluation, which includes + * user-defined Aggregators and may include system-defined counters. + */ + public CounterSet getCounters() { + return counters; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + private final DirectPipelineOptions options; + private boolean testSerializability = true; + private boolean testEncodability = true; + private boolean testUnorderedness = true; + + /** Returns a new DirectPipelineRunner. */ + private DirectPipelineRunner(DirectPipelineOptions options) { + this.options = options; + // (Re-)register standard IO factories. Clobbers any prior credentials. + IOChannelUtils.registerStandardIOFactories(options); + } + + public DirectPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public String toString() { return "DirectPipelineRunner#" + hashCode(); } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunner.java new file mode 100644 index 000000000000..8b134e98601c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunner.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.base.Preconditions; + +/** + * A PipelineRunner can execute, translate, or otherwise process a + * Pipeline. + * + * @param the type of the result of {@link #run}. + */ +public abstract class PipelineRunner { + + /** + * Constructs a runner from the provided options. + * + * @return The newly created runner. + */ + public static PipelineRunner fromOptions(PipelineOptions options) { + GcsOptions gcsOptions = PipelineOptionsValidator.validate(GcsOptions.class, options); + Preconditions.checkNotNull(options); + + // (Re-)register standard IO factories. Clobbers any prior credentials. + IOChannelUtils.registerStandardIOFactories(gcsOptions); + + @SuppressWarnings("unchecked") + PipelineRunner result = + InstanceBuilder.ofType(PipelineRunner.class) + .fromClass(options.getRunner()) + .fromFactoryMethod("fromOptions") + .withArg(PipelineOptions.class, options) + .build(); + return result; + } + + /** + * Processes the given Pipeline, returning the results. + */ + public abstract Results run(Pipeline pipeline); + + /** + * Applies a transform to the given input, returning the output. + * + *

The default implementation calls PTransform.apply(input), but can be overridden + * to customize behavior for a particular runner. + */ + public Output apply( + PTransform transform, Input input) { + return transform.apply(input); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/RecordingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/RecordingPipelineVisitor.java new file mode 100644 index 000000000000..cb1850d654bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/RecordingPipelineVisitor.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.ArrayList; +import java.util.List; + +/** + * Provides a simple PipelineVisitor which records the transformation tree. + * + *

Provided for internal unit tests. + */ +public class RecordingPipelineVisitor implements Pipeline.PipelineVisitor { + + public final List> transforms = new ArrayList<>(); + public final List values = new ArrayList<>(); + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + + @Override + public void visitTransform(TransformTreeNode node) { + transforms.add(node.getTransform()); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + values.add(value); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformHierarchy.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformHierarchy.java new file mode 100644 index 000000000000..53a90b2b8012 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformHierarchy.java @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; + +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.Set; + +/** + * Captures information about a collection of transformations and their + * associated PValues. + */ +public class TransformHierarchy { + private final Deque transformStack = new LinkedList<>(); + private final Map producingTransformNode = new HashMap<>(); + private final Map, TransformTreeNode> transformToNode = new HashMap<>(); + + public TransformHierarchy() { + // First element in the stack is the root node, holding all child nodes. + transformStack.add(new TransformTreeNode(null, null, "", null)); + } + + /** + * Returns the last TransformTreeNode on the stack. + */ + public TransformTreeNode getCurrent() { + return transformStack.peek(); + } + + /** + * Add a TransformTreeNode to the stack. + */ + public void pushNode(TransformTreeNode current) { + transformStack.push(current); + transformToNode.put(current.getTransform(), current); + } + + /** + * Removes the last TransformTreeNode from the stack. + */ + public void popNode() { + transformStack.pop(); + Preconditions.checkState(!transformStack.isEmpty()); + } + + /** + * Adds an input to the given node. + * + *

This forces the producing node to be finished. + */ + public void addInput(TransformTreeNode node, PInput input) { + for (PValue i : input.expand()) { + TransformTreeNode producer = producingTransformNode.get(i); + if (producer == null) { + throw new IllegalStateException("Producer unknown for input: " + i); + } + + producer.finishSpecifying(); + node.addInputProducer(i, producer); + } + } + + /** + * Sets the output of a transform node. + */ + public void setOutput(TransformTreeNode producer, POutput output) { + producer.setOutput(output); + + for (PValue o : output.expand()) { + producingTransformNode.put(o, producer); + } + } + + /** + * Returns the TransformTreeNode associated with a given transform. + */ + public TransformTreeNode getNode(PTransform transform) { + return transformToNode.get(transform); + } + + /** + * Visits all nodes in the transform hierarchy, in transitive order. + */ + public void visit(Pipeline.PipelineVisitor visitor, + Set visitedNodes) { + transformStack.peekFirst().visit(visitor, visitedNodes); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformTreeNode.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformTreeNode.java new file mode 100644 index 000000000000..efd28b354f07 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformTreeNode.java @@ -0,0 +1,237 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Provides internal tracking of transform relationships with helper methods + * for initialization and ordered visitation. + */ +public class TransformTreeNode { + private final TransformTreeNode enclosingNode; + + // The transform. If composite.isEmpty(), then this is a + // PrimitivePTransform, otherwise a composite PTransform. + private final PTransform transform; + + private final String fullName; + + // Nodes of a composite transform. + private final Collection parts = new ArrayList<>(); + + // Inputs to the transform, in expanded form and mapped to the producer + // of the input. + private final Map inputs = new HashMap<>(); + + // Input to the transform, in unexpanded form. + private final PInput input; + + // TODO: track which outputs need to be exported to parent. + // Output of the transform, in unexpanded form. + private POutput output; + + private boolean finishedSpecifying = false; + + /** + * Creates a new TransformTreeNode with the given parent and transform. + * + *

EnclosingNode and transform may both be null for a root-level node + * which holds all other nodes. + * + * @param enclosingNode the composite node containing this node + * @param transform the PTransform tracked by this node + * @param fullName the fully qualified name of the transform + * @param input the unexpanded input to the transform + */ + public TransformTreeNode(@Nullable TransformTreeNode enclosingNode, + @Nullable PTransform transform, + String fullName, + @Nullable PInput input) { + this.enclosingNode = enclosingNode; + this.transform = transform; + Preconditions.checkArgument((enclosingNode == null && transform == null) + || (enclosingNode != null && transform != null), + "EnclosingNode and transform must both be specified, or both be null"); + this.fullName = fullName; + this.input = input; + } + + /** + * Returns the transform associated with this transform node. + */ + public PTransform getTransform() { + return transform; + } + + /** + * Returns the enclosing composite transform node, or null if there is none. + */ + public TransformTreeNode getEnclosingNode() { + return enclosingNode; + } + + /** + * Adds a composite operation to the transform node. + * + *

As soon as a node is added, the transform node is considered a + * composite operation instead of a primitive transform. + */ + public void addComposite(TransformTreeNode node) { + parts.add(node); + } + + /** + * Returns true if this node represents a composite transform. + */ + public boolean isCompositeNode() { + return !parts.isEmpty(); + } + + public String getFullName() { + return fullName; + } + + /** + * Adds an input to the transform node. + */ + public void addInputProducer(PValue expandedInput, TransformTreeNode producer) { + Preconditions.checkState(!finishedSpecifying); + inputs.put(expandedInput, producer); + } + + /** + * Returns the transform input, in unexpanded form. + */ + public PInput getInput() { + return input; + } + + /** + * Returns a mapping of inputs to the producing nodes for all inputs to + * the transform. + */ + public Map getInputs() { + return Collections.unmodifiableMap(inputs); + } + + /** + * Adds an output to the transform node. + */ + public void setOutput(POutput output) { + Preconditions.checkState(!finishedSpecifying); + Preconditions.checkState(this.output == null); + this.output = output; + } + + /** + * Returns the transform output, in unexpanded form. + */ + public POutput getOutput() { + return output; + } + + /** + * Returns the transform outputs, in expanded form. + */ + public Collection getExpandedOutputs() { + if (output != null) { + return output.expand(); + } else { + return Collections.emptyList(); + } + } + + /** + * Visit the transform node. + * + *

Provides an ordered visit of the input values, the primitive + * transform (or child nodes for composite transforms), then the + * output values. + */ + public void visit(Pipeline.PipelineVisitor visitor, + Set visitedValues) { + if (!finishedSpecifying) { + finishSpecifying(); + } + + // Visit inputs. + for (Map.Entry entry : inputs.entrySet()) { + if (visitedValues.add(entry.getKey())) { + visitor.visitValue(entry.getKey(), entry.getValue()); + } + } + + if (isCompositeNode()) { + visitor.enterCompositeTransform(this); + for (TransformTreeNode child : parts) { + child.visit(visitor, visitedValues); + } + visitor.leaveCompositeTransform(this); + } else { + visitor.visitTransform(this); + } + + // Visit outputs. + for (PValue pValue : getExpandedOutputs()) { + if (visitedValues.add(pValue)) { + visitor.visitValue(pValue, this); + } + } + } + + /** + * Finish specifying a transform. + * + *

All inputs are finished first, then the transform, then + * all outputs. + */ + public void finishSpecifying() { + if (finishedSpecifying) { + return; + } + finishedSpecifying = true; + + for (TransformTreeNode input : inputs.values()) { + if (input != null) { + input.finishSpecifying(); + } + } + + if (transform != null) { + transform.finishSpecifying(); + } + + if (output != null) { + output.finishSpecifyingOutput(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/AvroIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/AvroIOTranslator.java new file mode 100644 index 000000000000..d7e36c54fc05 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/AvroIOTranslator.java @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.ShardNameTemplate; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; + +/** + * Avro transform support code for the Dataflow backend. + */ +public class AvroIOTranslator { + + /** + * Implements AvroIO Read translation for the Dataflow backend. + */ + public static class ReadTranslator implements TransformTranslator { + + @Override + public void translate( + AvroIO.Read.Bound transform, + TranslationContext context) { + translateReadHelper(transform, context); + } + + private void translateReadHelper( + AvroIO.Read.Bound transform, + TranslationContext context) { + if (context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException("AvroIO not supported in streaming mode."); + } + + // Only GCS paths are permitted for filepatterns in the DataflowPipelineRunner. + GcsPath gcsPath = GcsPath.fromUri(transform.getFilepattern()); + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, "avro"); + context.addInput(PropertyNames.FILEPATTERN, gcsPath); + context.addValueOnlyOutput(PropertyNames.OUTPUT, transform.getOutput()); + // TODO: Orderedness? + } + } + + /** + * Implements AvroIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator implements TransformTranslator { + + @Override + public void translate( + AvroIO.Write.Bound transform, + TranslationContext context) { + translateWriteHelper(transform, context); + } + + private void translateWriteHelper( + AvroIO.Write.Bound transform, + TranslationContext context) { + // Only GCS paths are permitted for filepatterns in the DataflowPipelineRunner. + GcsPath gcsPath = GcsPath.fromUri(transform.getFilenamePrefix()); + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + + // TODO: drop this check when server supports alternative templates. + switch (transform.getShardTemplate()) { + case ShardNameTemplate.INDEX_OF_MAX: + break; // supported by server + case "": + // Empty shard template allowed - forces single output. + Preconditions.checkArgument(transform.getNumShards() <= 1, + "Num shards must be <= 1 when using an empty sharding template"); + break; + default: + throw new UnsupportedOperationException("Shard template " + + transform.getShardTemplate() + + " not yet supported by Dataflow service"); + } + + context.addInput(PropertyNames.FORMAT, "avro"); + context.addInput(PropertyNames.FILENAME_PREFIX, gcsPath); + context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, transform.getShardTemplate()); + context.addInput(PropertyNames.FILENAME_SUFFIX, transform.getFilenameSuffix()); + + long numShards = transform.getNumShards(); + if (numShards > 0) { + context.addInput(PropertyNames.NUM_SHARDS, numShards); + } + + context.addEncodingInput( + WindowedValue.getValueOnlyCoder( + AvroCoder.of(transform.getType(), transform.getSchema()))); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/BigQueryIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/BigQueryIOTranslator.java new file mode 100644 index 000000000000..fd2731949c41 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/BigQueryIOTranslator.java @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.api.client.json.JsonFactory; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.TableReference; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator; +import com.google.cloud.dataflow.sdk.util.ApiErrorExtractor; +import com.google.cloud.dataflow.sdk.util.BigQueryTableInserter; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +import java.io.IOException; + +/** + * BigQuery transform support code for the Dataflow backend. + */ +public class BigQueryIOTranslator { + private static final JsonFactory JSON_FACTORY = Transport.getJsonFactory(); + + /** + * Implements BigQueryIO Read translation for the Dataflow backend. + */ + public static class ReadTranslator + implements DataflowPipelineTranslator.TransformTranslator { + + @Override + public void translate(BigQueryIO.Read.Bound transform, + DataflowPipelineTranslator.TranslationContext context) { + TableReference table = transform.getTable(); + if (table.getProjectId() == null) { + table.setProjectId(context.getPipelineOptions().getProject()); + } + + // Check for source table presence for early failure notification. + // Note that a presence check can fail if the table or dataset are created by earlier stages + // of the pipeline. For these cases the withoutValidation method can be used to disable + // the check. + if (transform.getValidate()) { + verifyDatasetPresence(context.getPipelineOptions(), table); + verifyTablePresence(context.getPipelineOptions(), table); + } + + // Actual translation. + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, "bigquery"); + context.addInput(PropertyNames.BIGQUERY_TABLE, table.getTableId()); + context.addInput(PropertyNames.BIGQUERY_DATASET, table.getDatasetId()); + if (table.getProjectId() != null) { + context.addInput(PropertyNames.BIGQUERY_PROJECT, table.getProjectId()); + } + context.addValueOnlyOutput(PropertyNames.OUTPUT, transform.getOutput()); + } + } + + /** + * Implements BigQueryIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator + implements DataflowPipelineTranslator.TransformTranslator { + + @Override + public void translate(BigQueryIO.Write.Bound transform, + DataflowPipelineTranslator.TranslationContext context) { + if (context.getPipelineOptions().isStreaming()) { + // Streaming is handled by the streaming runner. + throw new AssertionError( + "BigQueryIO is specified to use streaming write in batch mode."); + } + + TableReference table = transform.getTable(); + if (table.getProjectId() == null) { + table.setProjectId(context.getPipelineOptions().getProject()); + } + + // Check for destination table presence and emptiness for early failure notification. + // Note that a presence check can fail if the table or dataset are created by earlier stages + // of the pipeline. For these cases the withoutValidation method can be used to disable + // the check. + if (transform.getValidate()) { + verifyDatasetPresence(context.getPipelineOptions(), table); + if (transform.getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) { + verifyTablePresence(context.getPipelineOptions(), table); + } + if (transform.getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) { + verifyTableEmpty(context.getPipelineOptions(), table); + } + } + + // Actual translation. + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.FORMAT, "bigquery"); + context.addInput(PropertyNames.BIGQUERY_TABLE, + table.getTableId()); + context.addInput(PropertyNames.BIGQUERY_DATASET, + table.getDatasetId()); + if (table.getProjectId() != null) { + context.addInput(PropertyNames.BIGQUERY_PROJECT, table.getProjectId()); + } + if (transform.getSchema() != null) { + try { + context.addInput(PropertyNames.BIGQUERY_SCHEMA, + JSON_FACTORY.toString(transform.getSchema())); + } catch (IOException exn) { + throw new IllegalArgumentException("Invalid table schema.", exn); + } + } + context.addInput( + PropertyNames.BIGQUERY_CREATE_DISPOSITION, + transform.getCreateDisposition().name()); + context.addInput( + PropertyNames.BIGQUERY_WRITE_DISPOSITION, + transform.getWriteDisposition().name()); + // Set sink encoding to TableRowJsonCoder. + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(TableRowJsonCoder.of())); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private static void verifyDatasetPresence( + BigQueryOptions options, + TableReference table) { + try { + Bigquery client = Transport.newBigQueryClient(options).build(); + client.datasets().get(table.getProjectId(), table.getDatasetId()) + .execute(); + } catch (IOException e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if (errorExtractor.itemNotFound(e)) { + throw new IllegalArgumentException( + "BigQuery dataset not found for table: " + BigQueryIO.toTableSpec(table), e); + } else { + throw new RuntimeException( + "unable to confirm BigQuery dataset presence", e); + } + } + } + + private static void verifyTablePresence( + BigQueryOptions options, + TableReference table) { + try { + Bigquery client = Transport.newBigQueryClient(options).build(); + client.tables().get(table.getProjectId(), table.getDatasetId(), table.getTableId()) + .execute(); + } catch (IOException e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if (errorExtractor.itemNotFound(e)) { + throw new IllegalArgumentException( + "BigQuery table not found: " + BigQueryIO.toTableSpec(table), e); + } else { + throw new RuntimeException( + "unable to confirm BigQuery table presence", e); + } + } + } + + private static void verifyTableEmpty( + BigQueryOptions options, + TableReference table) { + try { + Bigquery client = Transport.newBigQueryClient(options).build(); + BigQueryTableInserter inserter = new BigQueryTableInserter(client, table); + if (!inserter.isEmpty()) { + throw new IllegalArgumentException( + "BigQuery table is not empty: " + BigQueryIO.toTableSpec(table)); + } + } catch (IOException e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if (errorExtractor.itemNotFound(e)) { + // Nothing to do. If the table does not exist, it is considered empty. + } else { + throw new RuntimeException( + "unable to confirm BigQuery table emptiness", e); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DatastoreIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DatastoreIOTranslator.java new file mode 100644 index 000000000000..4292199174a1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DatastoreIOTranslator.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.cloud.dataflow.sdk.io.DatastoreIO; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; + +/** + * Datastore transform support code for the Dataflow backend. + */ +public class DatastoreIOTranslator { + + /** + * Implements DatastoreIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator implements TransformTranslator { + @Override + public void translate( + DatastoreIO.Write.Bound transform, + TranslationContext context) { + // TODO: Not implemented yet. + // translateWriteHelper(transform, context); + throw new UnsupportedOperationException("Write only supports direct mode now."); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java new file mode 100644 index 000000000000..706397bddd37 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +/** + * Pubsub transform support code for the Dataflow backend. + */ +public class PubsubIOTranslator { + + /** + * Implements PubsubIO Read translation for the Dataflow backend. + */ + public static class ReadTranslator implements TransformTranslator { + @Override + public void translate( + PubsubIO.Read.Bound transform, + TranslationContext context) { + translateReadHelper(transform, context); + } + + /* + private static void translateReadHelper( + */ + + private void translateReadHelper( + PubsubIO.Read.Bound transform, + TranslationContext context) { + if (!context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException("PubsubIO can only be used in streaming mode."); + } + + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, "pubsub"); + if (transform.getTopic() != null) { + context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic()); + } + if (transform.getSubscription() != null) { + context.addInput(PropertyNames.PUBSUB_SUBSCRIPTION, transform.getSubscription()); + } + context.addValueOnlyOutput(PropertyNames.OUTPUT, transform.getOutput()); + // TODO: Orderedness? + } + } + + /** + * Implements PubsubIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator implements TransformTranslator { + @Override + public void translate( + PubsubIO.Write.Bound transform, + TranslationContext context) { + translateWriteHelper(transform, context); + } + + private void translateWriteHelper( + PubsubIO.Write.Bound transform, + TranslationContext context) { + if (!context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException("PubsubIO can only be used in streaming mode."); + } + + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.FORMAT, "pubsub"); + context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic()); + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(transform.getInput().getCoder())); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java new file mode 100644 index 000000000000..05a44648eba9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.io.ShardNameTemplate; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; + +/** + * TextIO transform support code for the Dataflow backend. + */ +public class TextIOTranslator { + + /** + * Implements TextIO Read translation for the Dataflow backend. + */ + public static class ReadTranslator implements TransformTranslator { + @Override + public void translate( + TextIO.Read.Bound transform, + TranslationContext context) { + translateReadHelper(transform, context); + } + + private void translateReadHelper( + TextIO.Read.Bound transform, + TranslationContext context) { + if (context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException("TextIO not supported in streaming mode."); + } + + // Only GCS paths are permitted for filepatterns in the DataflowPipelineRunner. + GcsPath gcsPath = GcsPath.fromUri(transform.getFilepattern()); + // Furthermore, on the service there is currently a limitation + // that the first wildcard character must occur after the last + // delimiter, and that the delimiter is fixed to '/' + if (!GcsUtil.GCS_READ_PATTERN.matcher(gcsPath.getObject()).matches()) { + throw new IllegalArgumentException( + "Unsupported wildcard usage in \"" + gcsPath + "\": " + + " all wildcards must occur after the final '/' delimiter."); + } + + context.addStep(transform, "ParallelRead"); + // TODO: How do we want to specify format and + // format-specific properties? + context.addInput(PropertyNames.FORMAT, "text"); + context.addInput(PropertyNames.FILEPATTERN, gcsPath); + context.addValueOnlyOutput(PropertyNames.OUTPUT, transform.getOutput()); + + // TODO: Orderedness? + } + } + + /** + * Implements TextIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator implements TransformTranslator { + @Override + public void translate( + TextIO.Write.Bound transform, + TranslationContext context) { + translateWriteHelper(transform, context); + } + + private void translateWriteHelper( + TextIO.Write.Bound transform, + TranslationContext context) { + if (context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException("TextIO not supported in streaming mode."); + } + + // Only GCS paths are permitted for filepatterns in the DataflowPipelineRunner. + GcsPath gcsPath = GcsPath.fromUri(transform.getFilenamePrefix()); + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + + // TODO: drop this check when server supports alternative templates. + switch (transform.getShardTemplate()) { + case ShardNameTemplate.INDEX_OF_MAX: + break; // supported by server + case "": + // Empty shard template allowed - forces single output. + Preconditions.checkArgument(transform.getNumShards() <= 1, + "Num shards must be <= 1 when using an empty sharding template"); + break; + default: + throw new UnsupportedOperationException("Shard template " + + transform.getShardTemplate() + + " not yet supported by Dataflow service"); + } + + // TODO: How do we want to specify format and + // format-specific properties? + context.addInput(PropertyNames.FORMAT, "text"); + context.addInput(PropertyNames.FILENAME_PREFIX, gcsPath); + context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, + transform.getShardNameTemplate()); + context.addInput(PropertyNames.FILENAME_SUFFIX, transform.getFilenameSuffix()); + + long numShards = transform.getNumShards(); + if (numShards > 0) { + context.addInput(PropertyNames.NUM_SHARDS, numShards); + } + + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(transform.getCoder())); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/package-info.java new file mode 100644 index 000000000000..c2fcc288cf3c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Implementation of the {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner}. + */ +package com.google.cloud.dataflow.sdk.runners.dataflow; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/package-info.java new file mode 100644 index 000000000000..c75fe2f8348e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/package-info.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines runners for executing Pipelines in different modes, including + * {@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner} and + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner}. + * + *

{@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner} executes a {@code Pipeline} + * locally, without contacting the Dataflow service. + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner} submits a + * {@code Pipeline} to the Dataflow service, which executes it on Dataflow-managed Compute Engine + * instances. {@code DataflowPipelineRunner} returns + * as soon as the {@code Pipeline} has been submitted. Use + * {@link com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner} to have execution + * updates printed to the console. + * + *

The runner is specified as part {@link com.google.cloud.dataflow.sdk.options.PipelineOptions}. + */ +package com.google.cloud.dataflow.sdk.runners; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ApplianceShuffleReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ApplianceShuffleReader.java new file mode 100644 index 000000000000..912c570f8efa --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ApplianceShuffleReader.java @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import java.io.IOException; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * ApplianceShuffleReader reads chunks of data from a shuffle dataset + * for a position range. + * + * It is a JNI wrapper of an equivalent C++ class. + */ +@ThreadSafe +public final class ApplianceShuffleReader implements ShuffleReader { + static { + ShuffleLibrary.load(); + } + + /** + * Pointer to the underlying native shuffle reader object. + */ + private long nativePointer; + + /** + * @param shuffleReaderConfig opaque configuration for creating a + * shuffle reader + */ + public ApplianceShuffleReader(byte[] shuffleReaderConfig) { + this.nativePointer = createFromConfig(shuffleReaderConfig); + } + + @Override + public void finalize() { + destroy(); + } + + /** + * Native methods for interacting with the underlying native shuffle client + * code. + */ + private native long createFromConfig(byte[] shuffleReaderConfig); + private native void destroy(); + + @Override + public native ReadChunkResult readIncludingPosition( + byte[] startPosition, byte[] endPosition) throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ApplianceShuffleWriter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ApplianceShuffleWriter.java new file mode 100644 index 000000000000..d6b3c7518e3e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ApplianceShuffleWriter.java @@ -0,0 +1,66 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import java.io.IOException; +import javax.annotation.concurrent.ThreadSafe; + +/** + * ApplianceShuffleWriter writes chunks of data to a shuffle dataset. + * + * It is a JNI wrapper of an equivalent C++ class. + */ +@ThreadSafe +public final class ApplianceShuffleWriter implements ShuffleWriter { + static { + ShuffleLibrary.load(); + } + + /** + * Pointer to the underlying native shuffle writer code. + */ + private long nativePointer; + + /** + * @param shuffleWriterConfig opaque configuration for creating a + * shuffle writer + * @param bufferSize the writer buffer size + */ + public ApplianceShuffleWriter(byte[] shuffleWriterConfig, + long bufferSize) { + this.nativePointer = createFromConfig(shuffleWriterConfig, bufferSize); + } + + @Override + public void finalize() { + destroy(); + } + + /** + * Native methods for interacting with the underlying native shuffle + * writer code. + */ + private native long createFromConfig(byte[] shuffleWriterConfig, + long bufferSize); + private native void destroy(); + + @Override + public native void write(byte[] chunk) throws IOException; + + @Override + public native void close() throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AssignWindowsParDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AssignWindowsParDoFn.java new file mode 100644 index 000000000000..f1ae7f11b937 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AssignWindowsParDoFn.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getBytes; + +import com.google.api.services.dataflow.model.MultiOutputInfo; +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.util.AssignWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; + +import java.util.Arrays; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A wrapper around an AssignWindowsDoFn. This class is the same as + * NormalParDoFn, except that it gets deserialized differently. + */ +class AssignWindowsParDoFn extends NormalParDoFn { + public static AssignWindowsParDoFn create( + PipelineOptions options, + CloudObject cloudUserFn, + String stepName, + @Nullable List sideInputInfos, + @Nullable List multiOutputInfos, + Integer numOutputs, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler sampler /* unused */) + throws Exception { + Object windowingFn = + SerializableUtils.deserializeFromByteArray( + getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), + "serialized window fn"); + if (!(windowingFn instanceof WindowingFn)) { + throw new Exception( + "unexpected kind of WindowingFn: " + windowingFn.getClass().getName()); + } + + DoFn assignWindowsDoFn = new AssignWindowsDoFn((WindowingFn) windowingFn); + + return new AssignWindowsParDoFn( + options, assignWindowsDoFn, stepName, executionContext, addCounterMutator); + } + + private AssignWindowsParDoFn( + PipelineOptions options, + DoFn fn, + String stepName, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator) { + super( + options, + fn, + PTuple.empty(), + Arrays.asList("output"), + stepName, + executionContext, + addCounterMutator); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSink.java new file mode 100644 index 000000000000..404b2d261fc9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSink.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericDatumWriter; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * A sink that writes Avro files. Records are written to the Avro file as a + * series of byte arrays. The coder provided is used to serialize each record + * into a byte array. + * + * @param the type of the elements written to the sink + */ +public class AvroByteSink extends Sink { + + final AvroSink avroSink; + final Coder coder; + private final Schema schema = Schema.create(Schema.Type.BYTES); + + public AvroByteSink(String filenamePrefix, Coder coder) { + this(filenamePrefix, "", "", 1, coder); + } + + public AvroByteSink(String filenamePrefix, String shardFormat, String filenameSuffix, + int shardCount, Coder coder) { + this.coder = coder; + avroSink = new AvroSink( + filenamePrefix, shardFormat, filenameSuffix, shardCount, + WindowedValue.getValueOnlyCoder(AvroCoder.of(ByteBuffer.class, schema))); + } + + @Override + public SinkWriter writer() throws IOException { + return new AvroByteFileWriter(); + } + + /** The SinkWriter for an AvroByteSink. */ + class AvroByteFileWriter implements SinkWriter { + + private final SinkWriter> avroFileWriter; + + public AvroByteFileWriter() throws IOException { + avroFileWriter = avroSink.writer(new GenericDatumWriter(schema)); + } + + @Override + public long add(T value) throws IOException { + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, value); + ByteBuffer encodedBuffer = ByteBuffer.wrap(encodedElem); + avroFileWriter.add(WindowedValue.valueInGlobalWindow(encodedBuffer)); + return encodedElem.length; + } + + @Override + public void close() throws IOException { + avroFileWriter.close(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSource.java new file mode 100644 index 000000000000..b71700a08fca --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSource.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericDatumReader; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * A source that reads Avro files. Records are read from the Avro file as a + * series of byte arrays. The coder provided is used to deserialize each record + * from a byte array. + * + * @param the type of the elements read from the source + */ +public class AvroByteSource extends Source { + + final AvroSource avroSource; + final Coder coder; + private final Schema schema = Schema.create(Schema.Type.BYTES); + + public AvroByteSource(String filename, + @Nullable Long startPosition, + @Nullable Long endPosition, + Coder coder) { + this.coder = coder; + avroSource = new AvroSource( + filename, startPosition, endPosition, + WindowedValue.getValueOnlyCoder(AvroCoder.of(ByteBuffer.class, schema))); + } + + @Override + public SourceIterator iterator() throws IOException { + return new AvroByteFileIterator(); + } + + class AvroByteFileIterator extends AbstractSourceIterator { + + private final SourceIterator> avroFileIterator; + + public AvroByteFileIterator() throws IOException { + avroFileIterator = avroSource.iterator( + new GenericDatumReader(schema)); + } + + @Override + public boolean hasNext() throws IOException { + return avroFileIterator.hasNext(); + } + + @Override + public T next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + ByteBuffer inBuffer = avroFileIterator.next().getValue(); + byte[] encodedElem = new byte[inBuffer.remaining()]; + inBuffer.get(encodedElem); + assert inBuffer.remaining() == 0; + inBuffer.clear(); + notifyElementRead(encodedElem.length); + return CoderUtils.decodeFromByteArray(coder, encodedElem); + } + + @Override + public void close() throws IOException { + avroFileIterator.close(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSink.java new file mode 100644 index 000000000000..64fe691aa41f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSink.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.WindowedValue.ValueOnlyWindowedValueCoder; +import static com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.util.ShardingWritableByteChannel; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.io.DatumWriter; + +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; +import java.util.Random; + +/** + * A sink that writes Avro files. + * + * @param the type of the elements written to the sink + */ +public class AvroSink extends Sink> { + + final String filenamePrefix; + final String shardFormat; + final String filenameSuffix; + final int shardCount; + final AvroCoder avroCoder; + final Schema schema; + + public AvroSink(String filename, WindowedValueCoder coder) { + this(filename, "", "", 1, coder); + } + + public AvroSink(String filenamePrefix, String shardFormat, String filenameSuffix, int shardCount, + WindowedValueCoder coder) { + if (!(coder instanceof ValueOnlyWindowedValueCoder)) { + throw new IllegalArgumentException("Expected ValueOnlyWindowedValueCoder"); + } + + if (!(coder.getValueCoder() instanceof AvroCoder)) { + throw new IllegalArgumentException("AvroSink requires an AvroCoder"); + } + + this.filenamePrefix = filenamePrefix; + this.shardFormat = shardFormat; + this.filenameSuffix = filenameSuffix; + this.shardCount = shardCount; + this.avroCoder = (AvroCoder) coder.getValueCoder(); + this.schema = this.avroCoder.getSchema(); + } + + public SinkWriter> writer(DatumWriter datumWriter) throws IOException { + WritableByteChannel writer = IOChannelUtils.create( + filenamePrefix, shardFormat, filenameSuffix, shardCount, MimeTypes.BINARY); + + if (writer instanceof ShardingWritableByteChannel) { + return new AvroShardingFileWriter(datumWriter, (ShardingWritableByteChannel) writer); + } else { + return new AvroFileWriter(datumWriter, writer); + } + } + + @Override + public SinkWriter> writer() throws IOException { + return writer(avroCoder.createDatumWriter()); + } + + /** The SinkWriter for an AvroSink. */ + class AvroFileWriter implements SinkWriter> { + DataFileWriter fileWriter; + + public AvroFileWriter(DatumWriter datumWriter, WritableByteChannel outputChannel) + throws IOException { + fileWriter = new DataFileWriter<>(datumWriter); + fileWriter.create(schema, Channels.newOutputStream(outputChannel)); + } + + @Override + public long add(WindowedValue value) throws IOException { + fileWriter.append(value.getValue()); + // DataFileWriter doesn't support returning the length written. Use the + // coder instead. + return CoderUtils.encodeToByteArray(avroCoder, value.getValue()).length; + } + + @Override + public void close() throws IOException { + fileWriter.close(); + } + } + + /** The SinkWriter for an AvroSink, which supports sharding. */ + class AvroShardingFileWriter implements SinkWriter> { + private ArrayList fileWriters = new ArrayList<>(); + private final Random random = new Random(); + + public AvroShardingFileWriter( + DatumWriter datumWriter, ShardingWritableByteChannel outputChannel) throws IOException { + for (int i = 0; i < outputChannel.getNumShards(); i++) { + fileWriters.add(new AvroFileWriter(datumWriter, outputChannel.getChannel(i))); + } + } + + @Override + public long add(WindowedValue value) throws IOException { + return fileWriters.get(random.nextInt(fileWriters.size())).add(value); + } + + @Override + public void close() throws IOException { + for (AvroFileWriter fileWriter : fileWriters) { + fileWriter.close(); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkFactory.java new file mode 100644 index 000000000000..9a20d17aee22 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkFactory.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +/** + * Creates an AvroSink from a CloudObject spec. + */ +public final class AvroSinkFactory { + // Do not instantiate. + private AvroSinkFactory() {} + + public static Sink create(PipelineOptions options, + CloudObject spec, + Coder coder, + ExecutionContext executionContext) + throws Exception { + return create(spec, coder); + } + + static Sink create(CloudObject spec, Coder coder) + throws Exception { + String filename = getString(spec, PropertyNames.FILENAME); + + if (!(coder instanceof WindowedValueCoder)) { + return new AvroByteSink<>(filename, coder); + //throw new IllegalArgumentException("Expected WindowedValueCoder"); + } + + WindowedValueCoder windowedCoder = (WindowedValueCoder) coder; + if (windowedCoder.getValueCoder() instanceof AvroCoder) { + return new AvroSink(filename, windowedCoder); + } else { + return new AvroByteSink<>(filename, windowedCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSource.java new file mode 100644 index 000000000000..3f071cff2c7a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSource.java @@ -0,0 +1,203 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.WindowedValue.ValueOnlyWindowedValueCoder; +import static com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.SeekableInput; +import org.apache.avro.io.DatumReader; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.util.Collection; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * A source that reads Avro files. + * + * @param the type of the elements read from the source + */ +public class AvroSource extends Source> { + private static final int BUF_SIZE = 200; + final String filename; + @Nullable final Long startPosition; + @Nullable final Long endPosition; + final AvroCoder avroCoder; + private final Schema schema; + + public AvroSource(String filename, + @Nullable Long startPosition, + @Nullable Long endPosition, + WindowedValueCoder coder) { + if (!(coder instanceof ValueOnlyWindowedValueCoder)) { + throw new IllegalArgumentException("Expected ValueOnlyWindowedValueCoder"); + } + + if (!(coder.getValueCoder() instanceof AvroCoder)) { + throw new IllegalArgumentException("AvroSource requires an AvroCoder"); + } + + this.filename = filename; + this.startPosition = startPosition; + this.endPosition = endPosition; + this.avroCoder = (AvroCoder) coder.getValueCoder(); + this.schema = this.avroCoder.getSchema(); + } + + public SourceIterator> iterator(DatumReader datumReader) throws IOException { + IOChannelFactory factory = IOChannelUtils.getFactory(filename); + Collection inputs = factory.match(filename); + + if (inputs.size() == 1) { + String input = inputs.iterator().next(); + ReadableByteChannel reader = factory.open(input); + return new AvroFileIterator(datumReader, input, reader, startPosition, endPosition); + + } else { + if (startPosition != null || endPosition != null) { + throw new UnsupportedOperationException( + "Unable to apply range limits to multiple-input stream: " + + filename); + } + return new AvroFileMultiIterator(datumReader, factory, inputs.iterator()); + } + } + + @Override + public SourceIterator> iterator() throws IOException { + return iterator(avroCoder.createDatumReader()); + } + + class AvroFileMultiIterator extends LazyMultiSourceIterator> { + private final IOChannelFactory factory; + private final DatumReader datumReader; + + public AvroFileMultiIterator(DatumReader datumReader, + IOChannelFactory factory, + Iterator inputs) { + super(inputs); + this.factory = factory; + this.datumReader = datumReader; + } + + @Override + protected SourceIterator> open(String input) throws IOException { + return new AvroFileIterator(datumReader, input, factory.open(input), null, null); + } + } + + class AvroFileIterator extends AbstractSourceIterator> { + final DataFileReader fileReader; + final Long endOffset; + + public AvroFileIterator(DatumReader datumReader, + String filename, + ReadableByteChannel reader, + @Nullable Long startOffset, + @Nullable Long endOffset) + throws IOException { + if (!(reader instanceof SeekableByteChannel)) { + throw new UnsupportedOperationException( + "Unable to seek to offset in stream for " + filename); + } + SeekableByteChannel inChannel = (SeekableByteChannel) reader; + SeekableInput seekableInput = new SeekableByteChannelInput(inChannel); + this.fileReader = new DataFileReader<>(seekableInput, datumReader); + this.endOffset = endOffset; + if (startOffset != null && startOffset > 0) { + // Sync to the first record at or after startOffset. + fileReader.sync(startOffset); + } + } + + @Override + public boolean hasNext() throws IOException { + return fileReader.hasNext() + && (endOffset == null || !fileReader.pastSync(endOffset)); + } + + @Override + public WindowedValue next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + T next = fileReader.next(); + // DataFileReader doesn't seem to support getting the current position. + // The difference between tell() calls seems to be zero. Use the coder + // instead. + notifyElementRead(CoderUtils.encodeToByteArray(avroCoder, next).length); + return WindowedValue.valueInGlobalWindow(next); + } + + @Override + public void close() throws IOException { + fileReader.close(); + } + } + + /** + * An implementation of an Avro SeekableInput wrapping a + * SeekableByteChannel. + */ + static class SeekableByteChannelInput implements SeekableInput { + final SeekableByteChannel channel; + + public SeekableByteChannelInput(SeekableByteChannel channel) { + this.channel = channel; + } + + @Override + public void seek(long position) throws IOException { + channel.position(position); + } + + @Override + public long tell() throws IOException { + return channel.position(); + } + + @Override + public long length() throws IOException { + return channel.size(); + } + + @Override + public int read(byte[] b, int offset, int length) throws IOException { + return channel.read(ByteBuffer.wrap(b, offset, length)); + } + + @Override + public void close() throws IOException { + channel.close(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceFactory.java new file mode 100644 index 000000000000..329d8b66e2ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceFactory.java @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getLong; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +/** + * Creates an AvroSource from a CloudObject spec. + */ +public class AvroSourceFactory { + // Do not instantiate. + private AvroSourceFactory() {} + + public static Source create(PipelineOptions options, + CloudObject spec, + Coder coder, + ExecutionContext executionContext) + throws Exception { + return create(spec, coder); + } + + static Source create(CloudObject spec, + Coder coder) + throws Exception { + String filename = getString(spec, PropertyNames.FILENAME); + Long startOffset = getLong(spec, PropertyNames.START_OFFSET, null); + Long endOffset = getLong(spec, PropertyNames.END_OFFSET, null); + + if (!(coder instanceof WindowedValueCoder)) { + return new AvroByteSource<>(filename, startOffset, endOffset, coder); + //throw new IllegalArgumentException("Expected WindowedValueCoder"); + } + + WindowedValueCoder windowedCoder = (WindowedValueCoder) coder; + if (windowedCoder.getValueCoder() instanceof AvroCoder) { + return new AvroSource(filename, startOffset, endOffset, windowedCoder); + } else { + return new AvroByteSource<>(filename, startOffset, endOffset, windowedCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySource.java new file mode 100644 index 000000000000..b43c942b3ed9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySource.java @@ -0,0 +1,114 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkNotNull; + +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.util.BigQueryTableRowIterator; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.logging.Logger; + +/** + * A source that reads a BigQuery table and yields TableRow objects. + * + *

The source is a wrapper over the {@code BigQueryTableRowIterator} class which issues a + * query for all rows of a table and then iterates over the result. There is no support for + * progress reporting because the source is used only in situations where the entire table must be + * read by each worker (i.e. the source is used as a side input). + */ +public class BigQuerySource extends Source { + private static final Logger LOG = + Logger.getLogger(BigQuerySource.class.getName()); + + final TableReference tableRef; + final BigQueryOptions bigQueryOptions; + final Bigquery bigQueryClient; + + /** Builds a BigQuery source using pipeline options to instantiate a Bigquery client. */ + public BigQuerySource(BigQueryOptions bigQueryOptions, TableReference tableRef) { + // Save pipeline options so that we can construct the BigQuery client on-demand whenever an + // iterator gets created. + this.bigQueryOptions = bigQueryOptions; + this.tableRef = tableRef; + this.bigQueryClient = null; + } + + /** Builds a BigQuerySource directly using a BigQuery client. */ + public BigQuerySource(Bigquery bigQueryClient, TableReference tableRef) { + this.bigQueryOptions = null; + this.tableRef = tableRef; + this.bigQueryClient = bigQueryClient; + } + + @Override + public SourceIterator iterator() throws IOException { + return new BigQuerySourceIterator( + bigQueryClient != null + ? bigQueryClient + : Transport.newBigQueryClient(bigQueryOptions).build(), + tableRef); + } + + /** + * A SourceIterator that yields TableRow objects for each row of a BigQuery table. + */ + class BigQuerySourceIterator extends AbstractSourceIterator { + + private BigQueryTableRowIterator rowIterator; + + public BigQuerySourceIterator(Bigquery bigQueryClient, TableReference tableRef) { + rowIterator = new BigQueryTableRowIterator(bigQueryClient, tableRef); + } + + @Override + public boolean hasNext() { + return rowIterator.hasNext(); + } + + @Override + public TableRow next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return rowIterator.next(); + } + + @Override + public Progress getProgress() { + // For now reporting progress is not supported because this source is used only when + // an entire table needs to be read by each worker (used as a side input for instance). + throw new UnsupportedOperationException(); + } + + @Override + public Position updateStopPosition(Progress proposedStopPosition) { + // For now updating the stop position is not supported because this source + // is used only when an entire table needs to be read by each worker (used + // as a side input for instance). + checkNotNull(proposedStopPosition); + throw new UnsupportedOperationException(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceFactory.java new file mode 100644 index 000000000000..682b7faa1400 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceFactory.java @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.api.services.bigquery.model.TableReference; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +/** + * Creates a BigQuerySource from a {@link CloudObject} spec. + */ +public class BigQuerySourceFactory { + // Do not instantiate. + private BigQuerySourceFactory() {} + + public static BigQuerySource create( + PipelineOptions options, CloudObject spec, Coder coder, + ExecutionContext executionContext) throws Exception { + return new BigQuerySource( + options.as(BigQueryOptions.class), + new TableReference() + .setProjectId(getString(spec, PropertyNames.BIGQUERY_PROJECT)) + .setDatasetId(getString(spec, PropertyNames.BIGQUERY_DATASET)) + .setTableId(getString(spec, PropertyNames.BIGQUERY_TABLE))); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ByteArrayShufflePosition.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ByteArrayShufflePosition.java new file mode 100644 index 000000000000..881f61b73020 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ByteArrayShufflePosition.java @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.decodeBase64; +import static com.google.api.client.util.Base64.encodeBase64URLSafeString; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.util.common.worker.ShufflePosition; +import com.google.common.primitives.UnsignedBytes; + +import java.util.Arrays; + +/** + * Represents a ShufflePosition as an array of bytes. + */ +public class ByteArrayShufflePosition implements Comparable, ShufflePosition { + private final byte[] position; + + public ByteArrayShufflePosition(byte[] position) { + this.position = position; + } + + public static ByteArrayShufflePosition fromBase64(String position) { + return ByteArrayShufflePosition.of(decodeBase64(position)); + } + + public static ByteArrayShufflePosition of(byte[] position) { + if (position == null) { + return null; + } + return new ByteArrayShufflePosition(position); + } + + public static byte[] getPosition(ShufflePosition shufflePosition) { + if (shufflePosition == null) { + return null; + } + Preconditions.checkArgument( + shufflePosition instanceof ByteArrayShufflePosition); + ByteArrayShufflePosition adapter = (ByteArrayShufflePosition) shufflePosition; + return adapter.getPosition(); + } + + public byte[] getPosition() { return position; } + + public String encodeBase64() { + return encodeBase64URLSafeString(position); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof ByteArrayShufflePosition) { + ByteArrayShufflePosition that = (ByteArrayShufflePosition) o; + return Arrays.equals(this.position, that.position); + } + return false; + } + + @Override + public int hashCode() { + return Arrays.hashCode(position); + } + + @Override + public String toString() { + return "ShufflePosition(" + (new String(position)) + ")"; + } + + @Override + public int compareTo(Object o) { + if (this == o) { + return 0; + } + return UnsignedBytes.lexicographicalComparator().compare( + position, ((ByteArrayShufflePosition) o).position); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ChunkingShuffleBatchReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ChunkingShuffleBatchReader.java new file mode 100644 index 000000000000..6f746ffec8c5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ChunkingShuffleBatchReader.java @@ -0,0 +1,97 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleBatchReader; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.ShufflePosition; +import com.google.common.io.ByteStreams; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; + +import javax.annotation.Nullable; + +/** + * ChunkingShuffleBatchReader reads data from a shuffle dataset using a + * ShuffleReader. + */ +final class ChunkingShuffleBatchReader implements ShuffleBatchReader { + private ShuffleReader reader; + + /** + * @param reader used to read from a shuffle dataset + */ + public ChunkingShuffleBatchReader(ShuffleReader reader) throws IOException { + this.reader = reader; + } + + @Override + public ShuffleBatchReader.Batch read( + @Nullable ShufflePosition startShufflePosition, + @Nullable ShufflePosition endShufflePosition) throws IOException { + @Nullable byte[] startPosition = + ByteArrayShufflePosition.getPosition(startShufflePosition); + @Nullable byte[] endPosition = + ByteArrayShufflePosition.getPosition(endShufflePosition); + + ShuffleReader.ReadChunkResult result = + reader.readIncludingPosition(startPosition, endPosition); + InputStream input = new ByteArrayInputStream(result.chunk); + ArrayList entries = new ArrayList<>(); + while (input.available() > 0) { + entries.add(getShuffleEntry(input)); + } + return new Batch(entries, result.nextStartPosition == null ? null + : ByteArrayShufflePosition.of(result.nextStartPosition)); + } + + /** + * Extracts a ShuffleEntry by parsing bytes from a given InputStream. + * + * @param input stream to read from + * @return parsed ShuffleEntry + */ + static ShuffleEntry getShuffleEntry(InputStream input) throws IOException { + byte[] position = getFixedLengthPrefixedByteArray(input); + byte[] key = getFixedLengthPrefixedByteArray(input); + byte[] skey = getFixedLengthPrefixedByteArray(input); + byte[] value = getFixedLengthPrefixedByteArray(input); + return new ShuffleEntry(position, key, skey, value); + } + + /** + * Extracts a length-prefix-encoded byte array from a given InputStream. + * + * @param input stream to read from + * @return parsed byte array + */ + static byte[] getFixedLengthPrefixedByteArray(InputStream input) + throws IOException { + DataInputStream dataInputStream = new DataInputStream(input); + int length = dataInputStream.readInt(); + if (length < 0) { + throw new IOException("invalid length: " + length); + } + byte[] data = new byte[(int) length]; + ByteStreams.readFully(dataInputStream, data); + return data; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ChunkingShuffleEntryWriter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ChunkingShuffleEntryWriter.java new file mode 100644 index 000000000000..9c55c181aebf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ChunkingShuffleEntryWriter.java @@ -0,0 +1,87 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * ChunkingShuffleEntryWriter buffers ShuffleEntries and writes them + * in batches to a shuffle dataset using a given writer. + */ +@NotThreadSafe +final class ChunkingShuffleEntryWriter implements ShuffleEntryWriter { + // Approximate maximum size of a chunk in bytes. + private static final int MAX_CHUNK_SIZE = 1 << 20; + + private static final byte[] EMPTY_BYTES = new byte[0]; + + private ByteArrayOutputStream chunk = new ByteArrayOutputStream(); + + private final ShuffleWriter writer; + + /** + * @param writer used to write chunks created by this writer + */ + public ChunkingShuffleEntryWriter(ShuffleWriter writer) { + this.writer = checkNotNull(writer); + } + + @Override + public long put(ShuffleEntry entry) throws IOException { + if (chunk.size() >= MAX_CHUNK_SIZE) { + writeChunk(); + } + + DataOutputStream output = new DataOutputStream(chunk); + return putFixedLengthPrefixedByteArray(entry.getKey(), output) + + putFixedLengthPrefixedByteArray(entry.getSecondaryKey(), output) + + putFixedLengthPrefixedByteArray(entry.getValue(), output); + } + + @Override + public void close() throws IOException { + writeChunk(); + writer.close(); + } + + private void writeChunk() throws IOException { + if (chunk.size() > 0) { + writer.write(chunk.toByteArray()); + chunk.reset(); + } + } + + static int putFixedLengthPrefixedByteArray(byte[] data, + DataOutputStream output) + throws IOException { + if (data == null) { + data = EMPTY_BYTES; + } + int bytesWritten = output.size(); + output.writeInt(data.length); + output.write(data, 0, data.length); + return output.size() - bytesWritten; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CombineValuesFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CombineValuesFn.java new file mode 100644 index 000000000000..16230571fae1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CombineValuesFn.java @@ -0,0 +1,219 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getBytes; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.api.services.dataflow.model.MultiOutputInfo; +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.base.Preconditions; + +import java.util.Arrays; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A wrapper around a decoded user value combining function. + */ +public class CombineValuesFn extends NormalParDoFn { + /** + * The optimizer may split run the user combiner in 3 separate + * phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees + * fit. The CombinerPhase dictates which DoFn is actually running in + * the worker. + * + * TODO: These strings are part of the service definition, and + * should be added into the definition of the ParDoInstruction, + * but the protiary definitions don't allow for enums yet. + */ + public static class CombinePhase { + public static final String ALL = "all"; + public static final String ADD = "add"; + public static final String MERGE = "merge"; + public static final String EXTRACT = "extract"; + } + + public static CombineValuesFn create( + PipelineOptions options, + CloudObject cloudUserFn, + String stepName, + @Nullable List sideInputInfos, + @Nullable List multiOutputInfos, + Integer numOutputs, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler /* unused */) + throws Exception { + Object deserializedFn = + SerializableUtils.deserializeFromByteArray( + getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), + "serialized user fn"); + Preconditions.checkArgument( + deserializedFn instanceof Combine.KeyedCombineFn); + Combine.KeyedCombineFn combineFn = (Combine.KeyedCombineFn) deserializedFn; + + // Get the combine phase, default to ALL. (The implementation + // doesn't have to split the combiner). + String phase = getString(cloudUserFn, PropertyNames.PHASE, CombinePhase.ALL); + + Preconditions.checkArgument( + sideInputInfos == null || sideInputInfos.size() == 0, + "unexpected side inputs for CombineValuesFn"); + Preconditions.checkArgument( + numOutputs == 1, "expected exactly one output for CombineValuesFn"); + + DoFn doFn = null; + switch (phase) { + case CombinePhase.ALL: + doFn = new CombineValuesDoFn(combineFn); + break; + case CombinePhase.ADD: + doFn = new AddInputsDoFn(combineFn); + break; + case CombinePhase.MERGE: + doFn = new MergeAccumulatorsDoFn(combineFn); + break; + case CombinePhase.EXTRACT: + doFn = new ExtractOutputDoFn(combineFn); + break; + default: + throw new IllegalArgumentException( + "phase must be one of 'all', 'add', 'merge', 'extract'"); + } + return new CombineValuesFn(options, doFn, stepName, executionContext, addCounterMutator); + } + + private CombineValuesFn( + PipelineOptions options, + DoFn doFn, + String stepName, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator) { + super( + options, + doFn, + PTuple.empty(), + Arrays.asList("output"), + stepName, + executionContext, + addCounterMutator); + } + + /** + * The ALL phase is the unsplit combiner, in case combiner lifting + * is disabled or the optimizer chose not to lift this combiner. + */ + private static class CombineValuesDoFn + extends DoFn>, KV>{ + private final Combine.KeyedCombineFn combineFn; + + private CombineValuesDoFn( + Combine.KeyedCombineFn combineFn) { + this.combineFn = combineFn; + } + + @Override + public void processElement(ProcessContext c) { + KV> kv = (KV>) c.element(); + K key = (K) kv.getKey(); + + c.output(KV.of(key, this.combineFn.apply(key, kv.getValue()))); + } + } + + /** + * ADD phase: KV> -> KV + */ + private static class AddInputsDoFn + extends DoFn>, KV>{ + private final Combine.KeyedCombineFn combineFn; + + private AddInputsDoFn( + Combine.KeyedCombineFn combineFn) { + this.combineFn = combineFn; + } + + @Override + public void processElement(ProcessContext c) { + KV> kv = (KV>) c.element(); + K key = kv.getKey(); + VA accum = this.combineFn.createAccumulator(key); + for (VI input : kv.getValue()) { + this.combineFn.addInput(key, accum, input); + } + + c.output(KV.of(key, accum)); + } + } + + /** + * MERGE phase: KV> -> KV + */ + private static class MergeAccumulatorsDoFn + extends DoFn>, KV>{ + private final Combine.KeyedCombineFn combineFn; + + private MergeAccumulatorsDoFn( + Combine.KeyedCombineFn combineFn) { + this.combineFn = combineFn; + } + + @Override + public void processElement(ProcessContext c) { + KV> kv = (KV>) c.element(); + K key = kv.getKey(); + VA accum = this.combineFn.mergeAccumulators(key, kv.getValue()); + + c.output(KV.of(key, accum)); + } + } + + /** + * EXTRACT phase: KV> -> KV + */ + private static class ExtractOutputDoFn + extends DoFn, KV>{ + private final Combine.KeyedCombineFn combineFn; + + private ExtractOutputDoFn( + Combine.KeyedCombineFn combineFn) { + this.combineFn = combineFn; + } + + @Override + public void processElement(ProcessContext c) { + KV kv = (KV) c.element(); + K key = kv.getKey(); + VO output = this.combineFn.extractOutput(key, kv.getValue()); + + c.output(KV.of(key, output)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CopyableSeekableByteChannel.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CopyableSeekableByteChannel.java new file mode 100644 index 000000000000..660b37466557 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CopyableSeekableByteChannel.java @@ -0,0 +1,270 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkNotNull; +import static com.google.api.client.util.Preconditions.checkState; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SeekableByteChannel; + +import javax.annotation.concurrent.GuardedBy; + +/** + * A {@link SeekableByteChannel} that adds copy semantics. + * + *

This implementation uses a lock to ensure that only one thread accesses + * the underlying {@code SeekableByteChannel} at any given time. + * + *

{@link SeekableByteChannel#close} is called on the underlying channel once + * all {@code CopyableSeekableByteChannel} objects copied from the initial + * {@code CopyableSeekableByteChannel} are closed. + * + *

The implementation keeps track of the position of each + * {@code CopyableSeekableByteChannel}; on access, it synchronizes with the + * other {@code CopyableSeekableByteChannel} instances accessing the underlying + * channel, seeks to its own position, performs the operation, updates its local + * position, and returns the result. + */ +final class CopyableSeekableByteChannel implements SeekableByteChannel { + /** This particular stream's position in the base stream. */ + private long pos; + + /** + * The synchronization object keeping track of the base + * {@link SeekableByteChannel}, its reference count, and its current position. + * This also doubles as the lock shared by all + * {@link CopyableSeekableByteChannel} instances derived from some original + * instance. + */ + private final Sync sync; + + /** + * Indicates whether this {@link CopyableSeekableByteChannel} is closed. + * + *

Invariant: Unclosed channels own a reference to the base channel, + * allowing us to make {@link #close} idempotent. + * + *

This is only modified under the sync lock. + */ + private boolean closed; + + /** + * Constructs a new {@link CopyableSeekableByteChannel}. The supplied base + * channel will be closed when this channel and all derived channels are + * closed. + */ + public CopyableSeekableByteChannel(SeekableByteChannel base) throws IOException { + this(new Sync(base), 0); + + // Update the position to match the original stream's position. + // + // This doesn't actually need to be synchronized, but it's a little more + // obviously correct to always access sync.position while holding sync's + // internal monitor. + synchronized (sync) { + sync.position = base.position(); + pos = sync.position; + } + } + + /** + * The internal constructor used when deriving a new + * {@link CopyableSeekableByteChannel}. + * + *

N.B. This signature is deliberately incompatible with the public + * constructor. + * + *

Ordinarily, one would implement copy using a copy constructor, and pass + * the object being copied -- but that signature would be compatible with the + * public constructor creating a new set of + * {@code CopyableSeekableByteChannel} objects for some base channel. The + * copy constructor would still be the one called, since its type is more + * specific, but that's fragile; it'd be easy to tweak the signature of the + * constructor used for copies without changing callers, which would silently + * fall back to using the public constructor. So instead, we're careful to + * give this internal constructor its own unique signature. + */ + private CopyableSeekableByteChannel(Sync sync, long pos) { + this.sync = checkNotNull(sync); + checkState(sync.base.isOpen(), + "the base SeekableByteChannel is not open"); + synchronized (sync) { + sync.refCount++; + } + this.pos = pos; + this.closed = false; + } + + /** + * Creates a new {@link CopyableSeekableByteChannel} derived from an existing + * channel, referencing the same base channel. + */ + public CopyableSeekableByteChannel copy() throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + return new CopyableSeekableByteChannel(sync, pos); + } + } + + // SeekableByteChannel implementation + + @Override + public long position() throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + return pos; + } + } + + @Override + public CopyableSeekableByteChannel position(long newPosition) + throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + // Verify that the position is valid for the base channel. + sync.base.position(newPosition); + this.pos = newPosition; + this.sync.position = newPosition; + } + return this; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + reposition(); + int bytesRead = sync.base.read(dst); + notePositionAdded(bytesRead); + return bytesRead; + } + } + + @Override + public long size() throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + return sync.base.size(); + } + } + + @Override + public CopyableSeekableByteChannel truncate(long size) throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + sync.base.truncate(size); + return this; + } + } + + @Override + public int write(ByteBuffer src) throws IOException { + synchronized (sync) { + if (closed) { + throw new ClosedChannelException(); + } + reposition(); + int bytesWritten = sync.base.write(src); + notePositionAdded(bytesWritten); + return bytesWritten; + } + } + + @Override + public boolean isOpen() { + synchronized (sync) { + if (closed) { + return false; + } + return sync.base.isOpen(); + } + } + + @Override + public void close() throws IOException { + synchronized (sync) { + if (closed) { + return; + } + closed = true; + sync.refCount--; + if (sync.refCount == 0) { + sync.base.close(); + } + } + } + + /** + * Updates the base stream's position to match the position required by this + * {@link CopyableSeekableByteChannel}. + */ + @GuardedBy("sync") + private void reposition() throws IOException { + if (pos != sync.position) { + sync.base.position(pos); + sync.position = pos; + } + } + + /** + * Notes that the specified amount has been logically added to the current + * stream's position. + */ + @GuardedBy("sync") + private void notePositionAdded(int amount) { + if (amount < 0) { + return; // Handles EOF indicators. + } + pos += amount; + sync.position += amount; + } + + /** + * A simple value type used to synchronize a set of + * {@link CopyableSeekableByteChannel} instances referencing a single + * underlying channel. + */ + private static final class Sync { + // N.B. Another way to do this would be to implement something like a + // RefcountingForwardingSeekableByteChannel. Doing so would have the + // advantage of clearly isolating the mutable state, at the cost of a lot + // more code. + public final SeekableByteChannel base; + @GuardedBy("this") public long refCount = 0; + @GuardedBy("this") public long position = 0; + + public Sync(SeekableByteChannel base) throws IOException { + this.base = checkNotNull(base); + position = base.position(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CustomSourceFormatFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CustomSourceFormatFactory.java new file mode 100644 index 000000000000..1bb3db228a73 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/CustomSourceFormatFactory.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.api.services.dataflow.model.Source; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.worker.CustomSourceFormat; + +import java.util.Map; + +/** + * Creates {@code CustomSourceFormat} objects from {@code Source}. + */ +public class CustomSourceFormatFactory { + private CustomSourceFormatFactory() {} + + public static CustomSourceFormat create(Source source) throws Exception { + Map spec = source.getSpec(); + + try { + return InstanceBuilder.ofType(CustomSourceFormat.class) + .fromClassName(getString(spec, PropertyNames.OBJECT_TYPE_NAME)) + .build(); + + } catch (ClassNotFoundException exn) { + throw new Exception( + "unable to create a custom source format from " + source, exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkProgressUpdater.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkProgressUpdater.java new file mode 100644 index 000000000000..f2d41cfcbc45 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkProgressUpdater.java @@ -0,0 +1,121 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.DataflowWorker.buildStatus; +import static com.google.cloud.dataflow.sdk.runners.worker.DataflowWorker.uniqueId; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.api.services.dataflow.model.WorkItem; +import com.google.api.services.dataflow.model.WorkItemServiceState; +import com.google.api.services.dataflow.model.WorkItemStatus; +import com.google.cloud.dataflow.sdk.options.DataflowWorkerHarnessOptions; +import com.google.cloud.dataflow.sdk.util.common.worker.WorkExecutor; +import com.google.cloud.dataflow.sdk.util.common.worker.WorkProgressUpdater; + +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * DataflowWorkProgressUpdater implements the WorkProgressUpdater + * interface for the Cloud Dataflow system. + */ +@NotThreadSafe +public class DataflowWorkProgressUpdater extends WorkProgressUpdater { + private static final Logger LOG = LoggerFactory.getLogger(DataflowWorkProgressUpdater.class); + + /** The Dataflow Worker WorkItem client */ + private final DataflowWorker.WorkUnitClient workUnitClient; + + /** The WorkItem for which work progress updates are sent. */ + private final WorkItem workItem; + + /** Options specifying information about the pipeline run by the worker.*/ + private final DataflowWorkerHarnessOptions options; + + public DataflowWorkProgressUpdater( + WorkItem workItem, + WorkExecutor worker, + DataflowWorker.WorkUnitClient workUnitClient, + DataflowWorkerHarnessOptions options) { + super(worker); + this.workItem = workItem; + this.workUnitClient = workUnitClient; + this.options = options; + } + + @Override + protected String workString() { + return uniqueId(workItem); + } + + @Override + protected long getWorkUnitLeaseExpirationTimestamp() { + return getLeaseExpirationTimestamp(workItem); + } + + @Override + protected void reportProgressHelper() throws Exception { + WorkItemStatus status = buildStatus( + workItem, false /*completed*/, + worker.getOutputCounters(), worker.getOutputMetrics(), options, + worker.getWorkerProgress(), stopPositionToService, + null /*sourceOperationResponse*/, null /*errors*/); + status.setRequestedLeaseDuration(toCloudDuration(Duration.millis(requestedLeaseDurationMs))); + + WorkItemServiceState result = workUnitClient.reportWorkItemStatus(status); + if (result != null) { + // Resets state after a successful progress report. + stopPositionToService = null; + + progressReportIntervalMs = nextProgressReportInterval( + fromCloudDuration(workItem.getReportStatusInterval()).getMillis(), + leaseRemainingTime(getLeaseExpirationTimestamp(result))); + + ApproximateProgress suggestedStopPoint = result.getSuggestedStopPoint(); + if (suggestedStopPoint == null && result.getSuggestedStopPosition() != null) { + suggestedStopPoint = new ApproximateProgress() + .setPosition(result.getSuggestedStopPosition()); + } + + if (suggestedStopPoint != null) { + LOG.info("Proposing stop progress on work unit {} at proposed stopping point {}", + workString(), suggestedStopPoint); + stopPositionToService = + worker.proposeStopPosition( + cloudProgressToSourceProgress(suggestedStopPoint)); + } + } + } + + /** Returns the given work unit's lease expiration timestamp. */ + private long getLeaseExpirationTimestamp(WorkItem workItem) { + return fromCloudTime(workItem.getLeaseExpireTime()).getMillis(); + } + + /** Returns the given work unit service state lease expiration timestamp. */ + private long getLeaseExpirationTimestamp(WorkItemServiceState workItemServiceState) { + return fromCloudTime(workItemServiceState.getLeaseExpireTime()).getMillis(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorker.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorker.java new file mode 100644 index 000000000000..5175d15aa882 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorker.java @@ -0,0 +1,330 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudSourceOperationResponseToSourceOperationResponse; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceOperationResponseToCloudSourceOperationResponse; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; + +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.api.services.dataflow.model.Status; +import com.google.api.services.dataflow.model.WorkItem; +import com.google.api.services.dataflow.model.WorkItemServiceState; +import com.google.api.services.dataflow.model.WorkItemStatus; +import com.google.cloud.dataflow.sdk.options.DataflowWorkerHarnessOptions; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudCounterUtils; +import com.google.cloud.dataflow.sdk.util.CloudMetricUtils; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.Metric; +import com.google.cloud.dataflow.sdk.util.common.worker.CustomSourceFormat; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.util.common.worker.WorkExecutor; +import com.google.cloud.dataflow.sdk.util.common.worker.WorkProgressUpdater; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * This is a semi-abstract harness for executing WorkItem tasks in + * Java workers. Concrete implementations need to implement a + * WorkUnitClient. + * + *

DataflowWorker presents one public interface, + * getAndPerformWork(), which uses the WorkUnitClient to get work, + * execute it, and update the work. + */ +public class DataflowWorker { + + private static final Logger LOG = LoggerFactory.getLogger(DataflowWorker.class); + + /** + * A client to get and update work items. + */ + private final WorkUnitClient workUnitClient; + + /** + * Pipeline options, initially provided via the constructor and + * partially provided via each work work unit. + */ + private final DataflowWorkerHarnessOptions options; + + public DataflowWorker(WorkUnitClient workUnitClient, + DataflowWorkerHarnessOptions options) { + this.workUnitClient = workUnitClient; + this.options = options; + } + + /** + * Gets WorkItem and performs it; returns true if work was + * successfully completed. + * + * getAndPerformWork may throw if there is a failure of the + * WorkUnitClient. + */ + public boolean getAndPerformWork() throws IOException { + WorkItem work = workUnitClient.getWorkItem(); + if (work == null) { + return false; + } + return doWork(work); + } + + /** + * Performs the given work; returns true if successful. + * + * @throws IOException Only if the WorkUnitClient fails. + */ + private boolean doWork(WorkItem workItem) throws IOException { + LOG.info("Executing: {}", workItem); + + WorkExecutor worker = null; + try { + // Populate PipelineOptions with data from work unit. + options.setProject(workItem.getProjectId()); + + ExecutionContext executionContext = new BatchModeExecutionContext(); + + if (workItem.getMapTask() != null) { + worker = MapTaskExecutorFactory.create(options, + workItem.getMapTask(), + executionContext); + + } else if (workItem.getSourceOperationTask() != null) { + worker = SourceOperationExecutorFactory.create( + workItem.getSourceOperationTask()); + + } else { + throw new RuntimeException("unknown kind of work item: " + workItem.toString()); + } + + WorkProgressUpdater progressUpdater = new DataflowWorkProgressUpdater( + workItem, worker, workUnitClient, options); + progressUpdater.startReportingProgress(); + + // Blocks while executing the work. + // TODO: refactor to allow multiple work unit + // processing threads. + worker.execute(); + + // Log all counter values for debugging purposes. + CounterSet counters = worker.getOutputCounters(); + for (Counter counter : counters) { + LOG.info("COUNTER {}.", counter); + } + + // Log all metrics for debugging purposes. + Collection> metrics = worker.getOutputMetrics(); + for (Metric metric : metrics) { + LOG.info("METRIC {}: {}", metric.getName(), metric.getValue()); + } + + // stopReportingProgress can throw an exception if the final progress + // update fails. For correctness, the task must then be marked as failed. + progressUpdater.stopReportingProgress(); + + // Report job success. + + // TODO: Find out a generic way for the WorkExecutor to report work-specific results + // into the work update. + CustomSourceFormat.SourceOperationResponse sourceOperationResponse = + (worker instanceof SourceOperationExecutor) + ? cloudSourceOperationResponseToSourceOperationResponse( + ((SourceOperationExecutor) worker).getResponse()) + : null; + reportStatus(options, "Success", workItem, counters, metrics, sourceOperationResponse, + null /*errors*/); + + return true; + + } catch (Throwable e) { + handleWorkError(workItem, worker, e); + return false; + + } finally { + if (worker != null) { + try { + worker.close(); + } catch (Exception exn) { + LOG.warn("Uncaught exception occurred during work unit shutdown:", exn); + } + } + } + } + + /** Handles the exception thrown when reading and executing the work. */ + private void handleWorkError( + WorkItem workItem, WorkExecutor worker, Throwable e) + throws IOException { + LOG.warn("Uncaught exception occurred during work unit execution:", e); + + // TODO: Look into moving the stack trace thinning + // into the client. + Throwable t = e instanceof UserCodeException ? e.getCause() : e; + Status error = new Status(); + error.setCode(2); // Code.UNKNOWN. TODO: Replace with a generated definition. + // TODO: Attach the stack trace as exception details, not to the message. + error.setMessage(buildCloudStackTrace(t)); + + reportStatus(options, "Failure", workItem, + worker == null ? null : worker.getOutputCounters(), + worker == null ? null : worker.getOutputMetrics(), + null /*sourceOperationResponse*/, + error == null ? null : Collections.singletonList(error)); + } + + /** + * Recursively goes through an exception, pulling out the stack trace. If the + * exception is a chained exception, it recursively goes through any causes + * and appends them to the stack trace. + */ + private static String buildCloudStackTrace(Throwable t) { + StringWriter result = new StringWriter(); + PrintWriter printResult = new PrintWriter(result); + + printResult.print("Exception: "); + for (;;) { + printResult.println(t.toString()); + for (StackTraceElement frame : t.getStackTrace()) { + printResult.println(frame.toString()); + } + t = t.getCause(); + if (t == null) { + break; + } + printResult.print("Caused by: "); + } + return result.toString(); + } + + private void reportStatus(DataflowWorkerHarnessOptions options, + String status, + WorkItem workItem, + @Nullable CounterSet counters, + @Nullable Collection> metrics, + @Nullable CustomSourceFormat.SourceOperationResponse + sourceOperationResponse, + @Nullable List errors) + throws IOException { + LOG.info("{} processing work item {}", status, uniqueId(workItem)); + WorkItemStatus workItemStatus = buildStatus(workItem, true /*completed*/, + counters, metrics, options, null, null, sourceOperationResponse, errors); + workUnitClient.reportWorkItemStatus(workItemStatus); + } + + static WorkItemStatus buildStatus( + WorkItem workItem, + boolean completed, + @Nullable CounterSet counters, + @Nullable Collection> metrics, + DataflowWorkerHarnessOptions options, + @Nullable Source.Progress progress, + @Nullable Source.Position stopPosition, + @Nullable CustomSourceFormat.SourceOperationResponse sourceOperationResponse, + @Nullable List errors) { + WorkItemStatus status = new WorkItemStatus(); + status.setWorkItemId(Long.toString(workItem.getId())); + status.setCompleted(completed); + + List counterUpdates = null; + List metricUpdates = null; + + if (counters != null) { + // Currently we lack a reliable exactly-once delivery mechanism for + // work updates, i.e. they can be retried or reordered, so sending + // delta updates could lead to double-counted or missed contributions. + // However, delta updates may be beneficial for performance. + // TODO: Implement exactly-once delivery and use deltas, + // if it ever becomes clear that deltas are necessary for performance. + boolean delta = false; + counterUpdates = CloudCounterUtils.extractCounters(counters, delta); + } + if (metrics != null) { + metricUpdates = CloudMetricUtils.extractCloudMetrics(metrics, options.getWorkerId()); + } + List updates = null; + if (counterUpdates == null) { + updates = metricUpdates; + } else if (metrics == null) { + updates = counterUpdates; + } else { + updates = new ArrayList<>(); + updates.addAll(counterUpdates); + updates.addAll(metricUpdates); + } + status.setMetricUpdates(updates); + + // TODO: Provide more structure representation of error, + // e.g., the serialized exception object. + if (errors != null) { + status.setErrors(errors); + } + + if (progress != null) { + status.setProgress(sourceProgressToCloudProgress(progress)); + } + if (stopPosition != null) { + status.setStopPosition(sourcePositionToCloudPosition(stopPosition)); + } + + if (workItem.getSourceOperationTask() != null) { + status.setSourceOperationResponse( + sourceOperationResponseToCloudSourceOperationResponse(sourceOperationResponse)); + } + + return status; + } + + static String uniqueId(WorkItem work) { + return work.getProjectId() + ";" + work.getJobId() + ";" + work.getId(); + } + + /** + * Abstract base class describing a client for WorkItem work units. + */ + public abstract static class WorkUnitClient { + /** + * Returns a new WorkItem unit for this Worker to work on or null + * if no work item is available. + */ + public abstract WorkItem getWorkItem() throws IOException; + + /** + * Reports a {@link WorkItemStatus} for an assigned {@link WorkItem}. + * + * @param workItemStatus the status to report + * @return a {@link WorkServiceState} (e.g. a new stop position) + */ + public abstract WorkItemServiceState reportWorkItemStatus( + WorkItemStatus workItemStatus) + throws IOException; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerHarness.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerHarness.java new file mode 100644 index 000000000000..fa17cf67390d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerHarness.java @@ -0,0 +1,231 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudTime; + +import com.google.api.client.util.Preconditions; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.LeaseWorkItemRequest; +import com.google.api.services.dataflow.model.LeaseWorkItemResponse; +import com.google.api.services.dataflow.model.ReportWorkItemStatusRequest; +import com.google.api.services.dataflow.model.ReportWorkItemStatusResponse; +import com.google.api.services.dataflow.model.WorkItem; +import com.google.api.services.dataflow.model.WorkItemServiceState; +import com.google.api.services.dataflow.model.WorkItemStatus; +import com.google.cloud.dataflow.sdk.options.DataflowWorkerHarnessOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.worker.logging.DataflowWorkerLoggingFormatter; +import com.google.cloud.dataflow.sdk.runners.worker.logging.DataflowWorkerLoggingInitializer; +import com.google.cloud.dataflow.sdk.util.Credentials; +import com.google.cloud.dataflow.sdk.util.GcsIOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.common.collect.ImmutableList; + +import org.joda.time.DateTime; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +import java.io.IOException; +import java.lang.Thread.UncaughtExceptionHandler; +import java.util.Collections; +import java.util.List; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * This is a harness for executing WorkItem tasks in Java workers. + *

+ * The worker fetches WorkItem units from the Dataflow Service. + * When the work is complete, the program sends results via the worker service API. + *

+ * Returns status code 0 on successful completion, 1 on any uncaught failures. + *

+ * TODO: add support for VM initialization via config. + * During initialization, we should take a configuration which specifies + * an initialization function, allowing user code to run on VM startup. + */ +public class DataflowWorkerHarness { + private static final Logger LOG = LoggerFactory.getLogger(DataflowWorkerHarness.class); + + private static final String APPLICATION_NAME = "DataflowWorkerHarness"; + + /** + * This uncaught exception handler logs the {@link Throwable} to the logger, {@link System#err} + * and exits the application with status code 1. + */ + static class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { + static final WorkerUncaughtExceptionHandler INSTANCE = new WorkerUncaughtExceptionHandler(); + + @Override + public void uncaughtException(Thread t, Throwable e) { + LOG.error("Uncaught exception in main thread. Exiting with status code 1.", e); + System.err.println("Uncaught exception in main thread. Exiting with status code 1."); + e.printStackTrace(); + System.exit(1); + } + } + + /** + * Fetches and processes work units from the Dataflow service. + */ + public static void main(String[] args) throws Exception { + Thread.currentThread().setUncaughtExceptionHandler(WorkerUncaughtExceptionHandler.INSTANCE); + new DataflowWorkerLoggingInitializer().initialize(); + + DataflowWorker worker = createFromSystemProperties(); + processWork(worker); + } + + // Visible for testing. + static void processWork(DataflowWorker worker) throws IOException { + worker.getAndPerformWork(); + } + + static DataflowWorker createFromSystemProperties() { + return create(PipelineOptionsFactory.createFromSystemProperties()); + } + + static DataflowWorker create(DataflowWorkerHarnessOptions options) { + MDC.put(DataflowWorkerLoggingFormatter.MDC_DATAFLOW_JOB_ID, options.getJobId()); + MDC.put(DataflowWorkerLoggingFormatter.MDC_DATAFLOW_WORKER_ID, options.getWorkerId()); + options.setAppName(APPLICATION_NAME); + + if (options.getGcpCredential() == null) { + try { + // Load the worker credential, otherwise the default is to load user + // credentials. + options.setGcpCredential(Credentials.getWorkerCredential(options)); + Preconditions.checkState(options.getGcpCredential() != null, + "Failed to obtain worker credential"); + } catch (Throwable e) { + LOG.warn("Unable to obtain any valid credentials. Worker inoperable.", e); + return null; + } + } + + // Configure standard IO factories. + IOChannelUtils.setIOFactory("gs", new GcsIOChannelFactory(options)); + + DataflowWorkUnitClient client = DataflowWorkUnitClient.fromOptions(options); + return new DataflowWorker(client, options); + } + + /** + * A Dataflow WorkUnit client that fetches WorkItems from the Dataflow service. + */ + @ThreadSafe + static class DataflowWorkUnitClient extends DataflowWorker.WorkUnitClient { + private final Dataflow dataflow; + private final DataflowWorkerHarnessOptions options; + + /** + * Creates a client that fetches WorkItems from the Dataflow service. + * + * @param options The pipeline options. + * @return A WorkItemClient that fetches WorkItems from the Dataflow service. + */ + static DataflowWorkUnitClient fromOptions(DataflowWorkerHarnessOptions options) { + return new DataflowWorkUnitClient( + Transport.newDataflowClient(options).build(), + options); + } + + /** + * Package private constructor for testing. + */ + DataflowWorkUnitClient(Dataflow dataflow, DataflowWorkerHarnessOptions options) { + this.dataflow = dataflow; + this.options = options; + } + + /** + * Gets a WorkItem from the Dataflow service. + */ + @Override + public WorkItem getWorkItem() throws IOException { + LeaseWorkItemRequest request = new LeaseWorkItemRequest(); + request.setFactory(Transport.getJsonFactory()); + request.setWorkItemTypes(ImmutableList.of( + "map_task", "seq_map_task", "remote_source_task")); + // All remote sources require the "remote_source" capability. Dataflow's + // custom sources are further tagged with the format "custom_source". + request.setWorkerCapabilities(ImmutableList.of( + options.getWorkerId(), "remote_source", PropertyNames.CUSTOM_SOURCE_FORMAT)); + request.setWorkerId(options.getWorkerId()); + request.setCurrentWorkerTime(toCloudTime(DateTime.now())); + + // This shouldn't be necessary, but a valid cloud duration string is + // required by the Google API parsing framework. TODO: Fix the framework + // so that an empty or not-present string can be used as a default value. + request.setRequestedLeaseDuration(toCloudDuration(Duration.standardSeconds(60))); + + LOG.debug("Leasing work: {}", request); + + LeaseWorkItemResponse response = dataflow.v1b3().projects().jobs().workItems().lease( + options.getProject(), options.getJobId(), request).execute(); + LOG.debug("Lease work response: {}", response); + + List workItems = response.getWorkItems(); + if (workItems == null || workItems.isEmpty()) { + // We didn't lease any work + return null; + } else if (workItems.size() > 1){ + throw new IOException( + "This version of the SDK expects no more than one work item from the service: " + + response); + } + + WorkItem work = response.getWorkItems().get(0); + if (work == null || work.getId() == null) { + return null; + } + + MDC.put(DataflowWorkerLoggingFormatter.MDC_DATAFLOW_WORK_ID, Long.toString(work.getId())); + // Looks like the work's a'ight. + return work; + } + + @Override + public WorkItemServiceState reportWorkItemStatus(WorkItemStatus workItemStatus) + throws IOException { + workItemStatus.setFactory(Transport.getJsonFactory()); + LOG.debug("Reporting work status: {}", workItemStatus); + ReportWorkItemStatusResponse result = + dataflow.v1b3().projects().jobs().workItems().reportStatus( + options.getProject(), options.getJobId(), + new ReportWorkItemStatusRequest() + .setWorkerId(options.getWorkerId()) + .setWorkItemStatuses(Collections.singletonList(workItemStatus)) + .setCurrentWorkerTime(toCloudTime(DateTime.now()))) + .execute(); + if (result == null || result.getWorkItemServiceStates() == null + || result.getWorkItemServiceStates().size() != 1) { + throw new IOException( + "This version of the SDK expects exactly one work item service state from the service"); + } + WorkItemServiceState state = result.getWorkItemServiceStates().get(0); + LOG.debug("ReportWorkItemStatus result: {}", state); + return state; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/FileBasedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/FileBasedSource.java new file mode 100644 index 000000000000..beea88747c1c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/FileBasedSource.java @@ -0,0 +1,259 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkNotNull; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudPositionToSourcePosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.ProgressTracker; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PushbackInputStream; +import java.nio.channels.Channels; +import java.util.Collection; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * Abstract base class for sources that read from files. + * + * @param the type of the elements read from the source + */ +public abstract class FileBasedSource extends Source { + protected static final int BUF_SIZE = 200; + protected final String filename; + @Nullable protected final Long startPosition; + @Nullable protected final Long endPosition; + protected final Coder coder; + protected final boolean useDefaultBufferSize; + + private static final Logger LOG = LoggerFactory.getLogger(FileBasedSource.class); + + protected FileBasedSource(String filename, + @Nullable Long startPosition, + @Nullable Long endPosition, + Coder coder, + boolean useDefaultBufferSize) { + this.filename = filename; + this.startPosition = startPosition; + this.endPosition = endPosition; + this.coder = coder; + this.useDefaultBufferSize = useDefaultBufferSize; + } + + /** + * Returns a new iterator for elements in the given range in the + * given file. If the range starts in the middle an element, this + * element is skipped as it is considered part of the previous + * range; if the last element that starts in the range finishes + * beyond the end position, it is still considered part of this + * range. In other words, the start position and the end position + * are "rounded up" to element boundaries. + * + * @param endPosition offset of the end position; null means end-of-file + */ + protected abstract SourceIterator newSourceIteratorForRangeInFile( + IOChannelFactory factory, String oneFile, long startPosition, + @Nullable Long endPosition) + throws IOException; + + /** + * Returns a new iterator for elements in the given files. Caller + * must ensure that the file collection is not empty. + */ + protected abstract SourceIterator newSourceIteratorForFiles( + IOChannelFactory factory, Collection files) throws IOException; + + @Override + public SourceIterator iterator() throws IOException { + IOChannelFactory factory = IOChannelUtils.getFactory(filename); + Collection inputs = factory.match(filename); + if (inputs.isEmpty()) { + throw new IOException("No match for file pattern '" + filename + "'"); + } + + if (startPosition != null || endPosition != null) { + if (inputs.size() != 1) { + throw new UnsupportedOperationException( + "Unable to apply range limits to multiple-input stream: " + + filename); + } + + return newSourceIteratorForRangeInFile( + factory, inputs.iterator().next(), + startPosition == null ? 0 : startPosition, endPosition); + } else { + return newSourceIteratorForFiles(factory, inputs); + } + } + + /** + * Abstract base class for file-based source iterators. + */ + protected abstract class FileBasedIterator extends AbstractSourceIterator { + protected final CopyableSeekableByteChannel seeker; + protected final PushbackInputStream stream; + protected final Long startOffset; + protected Long endOffset; + protected final ProgressTracker tracker; + protected ByteArrayOutputStream nextElement; + protected boolean nextElementComputed = false; + protected long offset; + + FileBasedIterator(CopyableSeekableByteChannel seeker, + long startOffset, + long offset, + @Nullable Long endOffset, + ProgressTracker tracker) throws IOException { + this.seeker = checkNotNull(seeker); + this.seeker.position(startOffset); + BufferedInputStream bufferedStream = useDefaultBufferSize + ? new BufferedInputStream(Channels.newInputStream(seeker)) + : new BufferedInputStream(Channels.newInputStream(seeker), BUF_SIZE); + this.stream = new PushbackInputStream(bufferedStream, BUF_SIZE); + this.startOffset = startOffset; + this.offset = offset; + this.endOffset = endOffset; + this.tracker = checkNotNull(tracker); + } + + /** + * Reads the next element. + * + * @return a {@code ByteArrayOutputStream} containing the contents + * of the element, or {@code null} if the end of the stream + * has been reached. + * @throws IOException if an I/O error occurs + */ + protected abstract ByteArrayOutputStream readElement() + throws IOException; + + @Override + public boolean hasNext() throws IOException { + computeNextElement(); + return nextElement != null; + } + + @Override + public T next() throws IOException { + advance(); + return CoderUtils.decodeFromByteArray(coder, nextElement.toByteArray()); + } + + void advance() throws IOException { + computeNextElement(); + if (nextElement == null) { + throw new NoSuchElementException(); + } + nextElementComputed = false; + } + + @Override + public Progress getProgress() { + // Currently we assume that only a offset position is reported as + // current progress. Source writer can override this method to update + // other metrics, e.g. completion percentage or remaining time. + com.google.api.services.dataflow.model.Position currentPosition = + new com.google.api.services.dataflow.model.Position(); + currentPosition.setByteOffset(offset); + + ApproximateProgress progress = new ApproximateProgress(); + progress.setPosition(currentPosition); + + return cloudProgressToSourceProgress(progress); + } + + @Override + public Position updateStopPosition(Progress proposedStopPosition) { + checkNotNull(proposedStopPosition); + + // Currently we only support stop position in byte offset of + // CloudPosition in a file-based Source. If stop position in + // other types is proposed, the end position in iterator will + // not be updated, and return null. + com.google.api.services.dataflow.model.ApproximateProgress stopPosition = + sourceProgressToCloudProgress(proposedStopPosition); + if (stopPosition == null) { + LOG.warn( + "A stop position other than CloudPosition is not supported now."); + return null; + } + + Long byteOffset = stopPosition.getPosition().getByteOffset(); + if (byteOffset == null) { + LOG.warn( + "A stop position other than byte offset is not supported in a " + + "file-based Source."); + return null; + } + if (byteOffset <= offset) { + // Proposed stop position is not after the current position: + // No stop position update. + return null; + } + + if (endOffset != null && byteOffset >= endOffset) { + // Proposed stop position is after the current stop (end) position: No + // stop position update. + return null; + } + + this.endOffset = byteOffset; + return cloudPositionToSourcePosition(stopPosition.getPosition()); + } + + /** + * Returns the end offset of the iterator. + * The method is called for test ONLY. + */ + Long getEndOffset() { + return this.endOffset; + } + + @Override + public void close() throws IOException { + stream.close(); + } + + private void computeNextElement() throws IOException { + if (nextElementComputed) { + return; + } + + if (endOffset == null || offset < endOffset) { + nextElement = readElement(); + } else { + nextElement = null; + } + nextElementComputed = true; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupAlsoByWindowsParDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupAlsoByWindowsParDoFn.java new file mode 100644 index 000000000000..adf0435e6e98 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupAlsoByWindowsParDoFn.java @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getBytes; +import static com.google.cloud.dataflow.sdk.util.Structs.getObject; + +import com.google.api.services.dataflow.model.MultiOutputInfo; +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.util.StreamingGroupAlsoByWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A wrapper around a GroupAlsoByWindowsDoFn. This class is the same as + * NormalParDoFn, except that it gets deserialized differently. + */ +class GroupAlsoByWindowsParDoFn extends NormalParDoFn { + public static GroupAlsoByWindowsParDoFn create( + PipelineOptions options, + CloudObject cloudUserFn, + String stepName, + @Nullable List sideInputInfos, + @Nullable List multiOutputInfos, + Integer numOutputs, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler sampler /* unused */) + throws Exception { + Object windowingFn = + SerializableUtils.deserializeFromByteArray( + getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), + "serialized window fn"); + if (!(windowingFn instanceof WindowingFn)) { + throw new Exception( + "unexpected kind of WindowingFn: " + windowingFn.getClass().getName()); + } + + byte[] serializedCombineFn = getBytes(cloudUserFn, PropertyNames.COMBINE_FN, null); + Object combineFn = null; + if (serializedCombineFn != null) { + combineFn = + SerializableUtils.deserializeFromByteArray(serializedCombineFn, "serialized combine fn"); + if (!(combineFn instanceof KeyedCombineFn)) { + throw new Exception("unexpected kind of KeyedCombineFn: " + combineFn.getClass().getName()); + } + } + + Map inputCoderObject = getObject(cloudUserFn, PropertyNames.INPUT_CODER); + + Coder inputCoder = Serializer.deserialize(inputCoderObject, Coder.class); + if (!(inputCoder instanceof WindowedValueCoder)) { + throw new Exception( + "Expected WindowedValueCoder for inputCoder, got: " + + inputCoder.getClass().getName()); + } + Coder elemCoder = ((WindowedValueCoder) inputCoder).getValueCoder(); + if (!(elemCoder instanceof KvCoder)) { + throw new Exception( + "Expected KvCoder for inputCoder, got: " + elemCoder.getClass().getName()); + } + + DoFn windowingDoFn = StreamingGroupAlsoByWindowsDoFn.create( + (WindowingFn) windowingFn, + ((KvCoder) elemCoder).getValueCoder()); + + return new GroupAlsoByWindowsParDoFn( + options, windowingDoFn, stepName, executionContext, addCounterMutator); + } + + private GroupAlsoByWindowsParDoFn( + PipelineOptions options, + DoFn fn, + String stepName, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator) { + super( + options, + fn, + PTuple.empty(), + Arrays.asList("output"), + stepName, + executionContext, + addCounterMutator); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSource.java new file mode 100644 index 000000000000..2d168879a21b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSource.java @@ -0,0 +1,368 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkNotNull; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudPositionToSourcePosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; + +import com.google.api.client.util.Preconditions; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.Reiterable; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; +import com.google.cloud.dataflow.sdk.util.common.worker.BatchingShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.GroupingShuffleEntryIterator; +import com.google.cloud.dataflow.sdk.util.common.worker.KeyGroupedShuffleEntries; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * A source that reads from a shuffled dataset and yields key-grouped data. + * + * @param the type of the keys read from the shuffle + * @param the type of the values read from the shuffle + */ +public class GroupingShuffleSource + extends Source>>> { + private static final Logger LOG = + LoggerFactory.getLogger(GroupingShuffleSource.class); + + final byte[] shuffleReaderConfig; + final String startShufflePosition; + final String stopShufflePosition; + final BatchModeExecutionContext executionContext; + + Coder keyCoder; + Coder valueCoder; + + public GroupingShuffleSource(PipelineOptions options, + byte[] shuffleReaderConfig, + String startShufflePosition, + String stopShufflePosition, + Coder>>> coder, + BatchModeExecutionContext executionContext) + throws Exception { + this.shuffleReaderConfig = shuffleReaderConfig; + this.startShufflePosition = startShufflePosition; + this.stopShufflePosition = stopShufflePosition; + this.executionContext = executionContext; + initCoder(coder); + } + + @Override + public SourceIterator>>> iterator() + throws IOException { + Preconditions.checkArgument(shuffleReaderConfig != null); + return iterator(new BatchingShuffleEntryReader( + new ChunkingShuffleBatchReader(new ApplianceShuffleReader( + shuffleReaderConfig)))); + } + + private void initCoder(Coder>>> coder) throws Exception { + if (!(coder instanceof WindowedValueCoder)) { + throw new Exception( + "unexpected kind of coder for WindowedValue: " + coder); + } + Coder>> elemCoder = + ((WindowedValueCoder>>) coder).getValueCoder(); + if (!(elemCoder instanceof KvCoder)) { + throw new Exception( + "unexpected kind of coder for elements read from " + + "a key-grouping shuffle: " + elemCoder); + } + KvCoder> kvCoder = (KvCoder>) elemCoder; + this.keyCoder = kvCoder.getKeyCoder(); + Coder> kvValueCoder = kvCoder.getValueCoder(); + if (!(kvValueCoder instanceof IterableCoder)) { + throw new Exception( + "unexpected kind of coder for values of KVs read from " + + "a key-grouping shuffle"); + } + IterableCoder iterCoder = (IterableCoder) kvValueCoder; + this.valueCoder = iterCoder.getElemCoder(); + } + + final SourceIterator>>> iterator(ShuffleEntryReader reader) + throws IOException { + return new GroupingShuffleSourceIterator(reader); + } + + /** + * A SourceIterator that reads from a ShuffleEntryReader and groups + * all the values with the same key. + * + *

A key limitation of this implementation is that all iterator accesses + * must by externally synchronized (the iterator objects are not individually + * thread-safe, and the iterators derived from a single original iterator + * access shared state which is not thread-safe). + * + *

To access the current position, the iterator must advance + * on-demand and cache the next batch of key grouped shuffle + * entries. The iterator does not advance a second time in @next() + * to avoid asking the underlying iterator to advance to the next + * key before the caller/user iterates over the values corresponding + * to the current key -- which would introduce a performance + * penalty. + */ + private final class GroupingShuffleSourceIterator + extends AbstractSourceIterator>>> { + // N.B. This class is *not* static; it uses the keyCoder, valueCoder, and + // executionContext from its enclosing GroupingShuffleSource. + + /** The iterator over shuffle entries, grouped by common key. */ + private final Iterator groups; + + /** The stop position. No records with a position at or after + * @stopPosition will be returned. Initialized + * to @AbstractShuffleSource.stopShufflePosition but can be + * dynamically updated via @updateStopPosition() (note that such + * updates can only decrease @stopPosition). + * + *

The granularity of the stop position is such that it can + * only refer to records at the boundary of a key. + */ + private ByteArrayShufflePosition stopPosition = null; + + /** The next group to be consumed, if available */ + private KeyGroupedShuffleEntries nextGroup = null; + + public GroupingShuffleSourceIterator(ShuffleEntryReader reader) { + stopPosition = ByteArrayShufflePosition.fromBase64(stopShufflePosition); + this.groups = + new GroupingShuffleEntryIterator(reader.read( + ByteArrayShufflePosition.fromBase64(startShufflePosition), + stopPosition)) { + @Override + protected void notifyElementRead(long byteSize) { + GroupingShuffleSource.this.notifyElementRead(byteSize); + } + }; + } + + private void advanceIfNecessary() { + if (nextGroup == null && groups.hasNext()) { + nextGroup = groups.next(); + } + } + + @Override + public boolean hasNext() throws IOException { + return hasNextInternal(); + } + + /** + * Returns false if the next group does not exist (i.e., no more + * records available) or the group is beyond @stopPosition. + */ + private boolean hasNextInternal() { + advanceIfNecessary(); + if (nextGroup == null) { + return false; + } + ByteArrayShufflePosition current = + ByteArrayShufflePosition.of(nextGroup.position); + return stopPosition == null || current.compareTo(stopPosition) < 0; + } + + @Override + public WindowedValue>> next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + KeyGroupedShuffleEntries group = nextGroup; + nextGroup = null; + + K key = CoderUtils.decodeFromByteArray(keyCoder, group.key); + if (executionContext != null) { + executionContext.setKey(key); + } + + return WindowedValue.valueInEmptyWindows( + KV.>of(key, new ValuesIterable(group.values))); + } + + /** + * Returns the position before the next {@code KV>} to be returned by the + * {@link GroupingShuffleSourceIterator}. Returns null if the + * {@link GroupingShuffleSourceIterator} is finished. + */ + @Override + public Progress getProgress() { + com.google.api.services.dataflow.model.Position currentPosition = + new com.google.api.services.dataflow.model.Position(); + ApproximateProgress progress = new ApproximateProgress(); + if (hasNextInternal()) { + ByteArrayShufflePosition current = + ByteArrayShufflePosition.of(nextGroup.position); + currentPosition.setShufflePosition(current.encodeBase64()); + } else { + if (stopPosition != null) { + currentPosition.setShufflePosition(stopPosition.encodeBase64()); + } else { + // The original stop position described the end of the + // shuffle-position-space (or infinity) and all records have + // been consumed. + progress.setPercentComplete((float) 1.0); + progress.setRemainingTime(toCloudDuration(Duration.ZERO)); + return cloudProgressToSourceProgress(progress); + } + } + + progress.setPosition(currentPosition); + return cloudProgressToSourceProgress(progress); + } + + /** + * Updates the stop position of the shuffle source to the position proposed. Ignores the + * proposed stop position if it is smaller than or equal to the position before the next + * {@code KV>} to be returned by the {@link GroupingShuffleSourceIterator}. + */ + @Override + public Position updateStopPosition(Progress proposedStopPosition) { + checkNotNull(proposedStopPosition); + com.google.api.services.dataflow.model.Position stopCloudPosition = + sourceProgressToCloudProgress(proposedStopPosition).getPosition(); + if (stopCloudPosition == null) { + LOG.warn( + "A stop position other than a Position is not supported now."); + return null; + } + + if (stopCloudPosition.getShufflePosition() == null) { + LOG.warn( + "A stop position other than shuffle position is not supported in " + + "a grouping shuffle source: " + stopCloudPosition.toString()); + return null; + } + ByteArrayShufflePosition newStopPosition = + ByteArrayShufflePosition.fromBase64(stopCloudPosition.getShufflePosition()); + + if (!hasNextInternal()) { + LOG.warn("Cannot update stop position to " + + stopCloudPosition.getShufflePosition() + + " since all input was consumed."); + return null; + } + ByteArrayShufflePosition current = + ByteArrayShufflePosition.of(nextGroup.position); + if (newStopPosition.compareTo(current) <= 0) { + LOG.warn("Proposed stop position: " + + stopCloudPosition.getShufflePosition() + " <= current position: " + + current.encodeBase64()); + return null; + } + + if (this.stopPosition != null + && newStopPosition.compareTo(this.stopPosition) >= 0) { + LOG.warn("Proposed stop position: " + + stopCloudPosition.getShufflePosition() + + " >= current stop position: " + + this.stopPosition.encodeBase64()); + return null; + } + + this.stopPosition = newStopPosition; + LOG.info("Updated the stop position to " + + stopCloudPosition.getShufflePosition()); + + return cloudPositionToSourcePosition(stopCloudPosition); + } + + /** + * Provides the {@link Reiterable} used to iterate through the values part + * of a {@code KV>} entry produced by a + * {@link GroupingShuffleSource}. + */ + private final class ValuesIterable implements Reiterable { + // N.B. This class is *not* static; it uses the valueCoder from + // its enclosing GroupingShuffleSource. + + private final Reiterable base; + + public ValuesIterable(Reiterable base) { + this.base = checkNotNull(base); + } + + @Override + public ValuesIterator iterator() { + return new ValuesIterator(base.iterator()); + } + } + + /** + * Provides the {@link Reiterator} used to iterate through the values part + * of a {@code KV>} entry produced by a + * {@link GroupingShuffleSource}. + */ + private final class ValuesIterator implements Reiterator { + // N.B. This class is *not* static; it uses the valueCoder from + // its enclosing GroupingShuffleSource. + + private final Reiterator base; + + public ValuesIterator(Reiterator base) { + this.base = checkNotNull(base); + } + + @Override + public boolean hasNext() { + return base.hasNext(); + } + + @Override + public V next() { + ShuffleEntry entry = base.next(); + try { + return CoderUtils.decodeFromByteArray(valueCoder, entry.getValue()); + } catch (IOException exn) { + throw new RuntimeException(exn); + } + } + + @Override + public void remove() { + base.remove(); + } + + @Override + public ValuesIterator copy() { + return new ValuesIterator(base.copy()); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSourceFactory.java new file mode 100644 index 000000000000..2229a77ddc10 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSourceFactory.java @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.decodeBase64; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; + +/** + * Creates a GroupingShuffleSource from a CloudObject spec. + */ +public class GroupingShuffleSourceFactory { + // Do not instantiate. + private GroupingShuffleSourceFactory() {} + + public static GroupingShuffleSource create( + PipelineOptions options, + CloudObject spec, + Coder>>> coder, + ExecutionContext executionContext) + throws Exception { + return create(options, spec, coder, + (BatchModeExecutionContext) executionContext); + } + + static GroupingShuffleSource create( + PipelineOptions options, + CloudObject spec, + Coder>>> coder, + BatchModeExecutionContext executionContext) + throws Exception { + return new GroupingShuffleSource<>( + options, + decodeBase64(getString(spec, PropertyNames.SHUFFLE_READER_CONFIG)), + getString(spec, PropertyNames.START_SHUFFLE_POSITION, null), + getString(spec, PropertyNames.END_SHUFFLE_POSITION, null), + coder, + executionContext); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySource.java new file mode 100644 index 000000000000..a0a524ee0c9b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySource.java @@ -0,0 +1,163 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkNotNull; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudPositionToSourcePosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static java.lang.Math.min; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * A source that yields a set of precomputed elements. + * + * @param the type of the elements read from the source + */ +public class InMemorySource extends Source { + private static final Logger LOG = LoggerFactory.getLogger(InMemorySource.class); + + final List encodedElements; + final int startIndex; + final int endIndex; + final Coder coder; + + public InMemorySource(List encodedElements, + @Nullable Long startIndex, + @Nullable Long endIndex, + Coder coder) { + this.encodedElements = encodedElements; + int maxIndex = encodedElements.size(); + if (startIndex == null) { + this.startIndex = 0; + } else { + if (startIndex < 0) { + throw new IllegalArgumentException("start index should be >= 0"); + } + this.startIndex = (int) min(startIndex, maxIndex); + } + if (endIndex == null) { + this.endIndex = maxIndex; + } else { + if (endIndex < this.startIndex) { + throw new IllegalArgumentException( + "end index should be >= start index"); + } + this.endIndex = (int) min(endIndex, maxIndex); + } + this.coder = coder; + } + + @Override + public SourceIterator iterator() throws IOException { + return new InMemorySourceIterator(); + } + + /** + * A SourceIterator that yields an in-memory list of elements. + */ + class InMemorySourceIterator extends AbstractSourceIterator { + int index; + int endPosition; + + public InMemorySourceIterator() { + index = startIndex; + endPosition = endIndex; + } + + @Override + public boolean hasNext() { + return index < endPosition; + } + + @Override + public T next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + String encodedElementString = encodedElements.get(index++); + // TODO: Replace with the real encoding used by the + // front end, when we know what it is. + byte[] encodedElement = + StringUtils.jsonStringToByteArray(encodedElementString); + notifyElementRead(encodedElement.length); + return CoderUtils.decodeFromByteArray(coder, encodedElement); + } + + @Override + public Progress getProgress() { + // Currently we assume that only a record index position is reported as + // current progress. Source writer can override this method to update + // other metrics, e.g. completion percentage or remaining time. + com.google.api.services.dataflow.model.Position currentPosition = + new com.google.api.services.dataflow.model.Position(); + currentPosition.setRecordIndex((long) index); + + ApproximateProgress progress = new ApproximateProgress(); + progress.setPosition(currentPosition); + + return cloudProgressToSourceProgress(progress); + } + + @Override + public Position updateStopPosition(Progress proposedStopPosition) { + checkNotNull(proposedStopPosition); + + // Currently we only support stop position in record index of + // an API Position in InMemorySource. If stop position in other types is + // proposed, the end position in iterator will not be updated, + // and return null. + com.google.api.services.dataflow.model.Position stopPosition = + sourceProgressToCloudProgress(proposedStopPosition).getPosition(); + if (stopPosition == null) { + LOG.warn( + "A stop position other than a Dataflow API Position is not currently supported."); + return null; + } + + Long recordIndex = stopPosition.getRecordIndex(); + if (recordIndex == null) { + LOG.warn( + "A stop position other than record index is not supported in InMemorySource."); + return null; + } + if (recordIndex <= index || recordIndex >= endPosition) { + // Proposed stop position is not after the current position or proposed + // stop position is after the current stop (end) position: No stop + // position update. + return null; + } + + this.endPosition = recordIndex.intValue(); + return cloudPositionToSourcePosition(stopPosition); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceFactory.java new file mode 100644 index 000000000000..3f2cd9c9a1db --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceFactory.java @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getLong; +import static com.google.cloud.dataflow.sdk.util.Structs.getStrings; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import java.util.Collections; + +/** + * Creates an InMemorySource from a CloudObject spec. + */ +public class InMemorySourceFactory { + // Do not instantiate. + private InMemorySourceFactory() {} + + public static InMemorySource create(PipelineOptions options, + CloudObject spec, + Coder coder, + ExecutionContext executionContext) + throws Exception { + return create(spec, coder); + } + + static InMemorySource create(CloudObject spec, + Coder coder) throws Exception { + return new InMemorySource<>( + getStrings(spec, + PropertyNames.ELEMENTS, Collections.emptyList()), + getLong(spec, PropertyNames.START_INDEX, null), + getLong(spec, PropertyNames.END_INDEX, null), + coder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/LazyMultiSourceIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/LazyMultiSourceIterator.java new file mode 100644 index 000000000000..3ccebd561756 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/LazyMultiSourceIterator.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import java.io.IOException; +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * Implements a SourceIterator over a collection of inputs. + * + * The sources are used sequentially, each consumed entirely before moving + * to the next source. + * + * The input is lazily constructed by using the abstract method {@code open} to + * create a source iterator for inputs on demand. This allows the resources to + * be produced lazily, as an open source iterator may consume process resources + * such as file descriptors. + */ +abstract class LazyMultiSourceIterator + extends Source.AbstractSourceIterator { + private final Iterator inputs; + Source.SourceIterator current; + + public LazyMultiSourceIterator(Iterator inputs) { + this.inputs = inputs; + } + + @Override + public boolean hasNext() throws IOException { + while (selectSource()) { + if (!current.hasNext()) { + current.close(); + current = null; + } else { + return true; + } + } + return false; + } + + @Override + public T next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.next(); + } + + @Override + public void close() throws IOException { + while (selectSource()) { + current.close(); + current = null; + } + } + + protected abstract Source.SourceIterator open(String input) + throws IOException; + + boolean selectSource() throws IOException { + if (current != null) { + return true; + } + if (inputs.hasNext()) { + current = open(inputs.next()); + return true; + } + return false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/MapTaskExecutorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/MapTaskExecutorFactory.java new file mode 100644 index 000000000000..095aa0876ee8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/MapTaskExecutorFactory.java @@ -0,0 +1,413 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.api.services.dataflow.model.FlattenInstruction; +import com.google.api.services.dataflow.model.InstructionInput; +import com.google.api.services.dataflow.model.InstructionOutput; +import com.google.api.services.dataflow.model.MapTask; +import com.google.api.services.dataflow.model.ParDoInstruction; +import com.google.api.services.dataflow.model.ParallelInstruction; +import com.google.api.services.dataflow.model.PartialGroupByKeyInstruction; +import com.google.api.services.dataflow.model.ReadInstruction; +import com.google.api.services.dataflow.model.WriteInstruction; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservable; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.util.common.worker.FlattenOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.MapTaskExecutor; +import com.google.cloud.dataflow.sdk.util.common.worker.Operation; +import com.google.cloud.dataflow.sdk.util.common.worker.OutputReceiver; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoFn; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.ReadOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.ReceivingOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.util.common.worker.WriteOperation; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Creates a MapTaskExecutor from a MapTask definition. + */ +public class MapTaskExecutorFactory { + /** + * Creates a new MapTaskExecutor from the given MapTask definition. + */ + public static MapTaskExecutor create(PipelineOptions options, + MapTask mapTask, + ExecutionContext context) + throws Exception { + List operations = new ArrayList<>(); + CounterSet counters = new CounterSet(); + String counterPrefix = mapTask.getStageName() + "-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counters.getAddCounterMutator()); + // Open-ended state. + stateSampler.setState("other"); + + // Instantiate operations for each instruction in the graph. + for (ParallelInstruction instruction : mapTask.getInstructions()) { + operations.add( + createOperation(options, instruction, context, operations, + counterPrefix, counters.getAddCounterMutator(), + stateSampler)); + } + + return new MapTaskExecutor(operations, counters, stateSampler); + } + + /** + * Creates an Operation from the given ParallelInstruction definition. + */ + static Operation createOperation( + PipelineOptions options, + ParallelInstruction instruction, + ExecutionContext executionContext, + List priorOperations, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + if (instruction.getRead() != null) { + return createReadOperation( + options, instruction, executionContext, priorOperations, + counterPrefix, addCounterMutator, stateSampler); + } else if (instruction.getWrite() != null) { + return createWriteOperation( + options, instruction, executionContext, priorOperations, + counterPrefix, addCounterMutator, stateSampler); + } else if (instruction.getParDo() != null) { + return createParDoOperation( + options, instruction, executionContext, priorOperations, + counterPrefix, addCounterMutator, stateSampler); + } else if (instruction.getPartialGroupByKey() != null) { + return createPartialGroupByKeyOperation( + options, instruction, executionContext, priorOperations, + counterPrefix, addCounterMutator, stateSampler); + } else if (instruction.getFlatten() != null) { + return createFlattenOperation( + options, instruction, executionContext, priorOperations, + counterPrefix, addCounterMutator, stateSampler); + } else { + throw new Exception("Unexpected instruction: " + instruction); + } + } + + static ReadOperation createReadOperation( + PipelineOptions options, + ParallelInstruction instruction, + ExecutionContext executionContext, + List priorOperations, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + ReadInstruction read = instruction.getRead(); + + Source source = + SourceFactory.create(options, read.getSource(), executionContext); + + OutputReceiver[] receivers = createOutputReceivers( + instruction, counterPrefix, addCounterMutator, stateSampler, 1); + + return new ReadOperation(instruction.getSystemName(), source, receivers, + counterPrefix, addCounterMutator, stateSampler); + } + + static WriteOperation createWriteOperation( + PipelineOptions options, + ParallelInstruction instruction, + ExecutionContext executionContext, + List priorOperations, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + WriteInstruction write = instruction.getWrite(); + + Sink sink = SinkFactory.create(options, write.getSink(), executionContext); + + OutputReceiver[] receivers = createOutputReceivers( + instruction, counterPrefix, addCounterMutator, stateSampler, 0); + + WriteOperation operation = + new WriteOperation(instruction.getSystemName(), sink, receivers, + counterPrefix, addCounterMutator, stateSampler); + + attachInput(operation, write.getInput(), priorOperations); + + return operation; + } + + static ParDoOperation createParDoOperation( + PipelineOptions options, + ParallelInstruction instruction, + ExecutionContext executionContext, + List priorOperations, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + ParDoInstruction parDo = instruction.getParDo(); + + ParDoFn fn = ParDoFnFactory.create( + options, + CloudObject.fromSpec(parDo.getUserFn()), + instruction.getSystemName(), + parDo.getSideInputs(), + parDo.getMultiOutputInfos(), + parDo.getNumOutputs(), + executionContext, + addCounterMutator, + stateSampler); + + OutputReceiver[] receivers = + createOutputReceivers(instruction, counterPrefix, addCounterMutator, + stateSampler, parDo.getNumOutputs()); + + ParDoOperation operation = + new ParDoOperation(instruction.getSystemName(), fn, receivers, + counterPrefix, addCounterMutator, stateSampler); + + attachInput(operation, parDo.getInput(), priorOperations); + + return operation; + } + + static PartialGroupByKeyOperation createPartialGroupByKeyOperation( + PipelineOptions options, + ParallelInstruction instruction, + ExecutionContext executionContext, + List priorOperations, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + PartialGroupByKeyInstruction pgbk = instruction.getPartialGroupByKey(); + + Coder coder = Serializer.deserialize(pgbk.getInputElementCodec(), Coder.class); + if (!(coder instanceof WindowedValueCoder)) { + throw new Exception( + "unexpected kind of input coder for PartialGroupByKeyOperation: " + coder); + } + Coder elemCoder = ((WindowedValueCoder) coder).getValueCoder(); + if (!(elemCoder instanceof KvCoder)) { + throw new Exception( + "unexpected kind of input element coder for PartialGroupByKeyOperation: " + elemCoder); + } + KvCoder kvCoder = (KvCoder) elemCoder; + Coder keyCoder = kvCoder.getKeyCoder(); + Coder valueCoder = kvCoder.getValueCoder(); + + OutputReceiver[] receivers = createOutputReceivers( + instruction, counterPrefix, addCounterMutator, stateSampler, 1); + + PartialGroupByKeyOperation operation = + new PartialGroupByKeyOperation(instruction.getSystemName(), + new CoderGroupingKeyCreator(keyCoder), + new CoderSizeEstimator(keyCoder), + new CoderSizeEstimator(valueCoder), + 0.001 /*sizeEstimatorSampleRate*/, + PairInfo.create(), + receivers, + counterPrefix, addCounterMutator, + stateSampler); + + attachInput(operation, pgbk.getInput(), priorOperations); + + return operation; + } + + /** + * Implements PGBKOp.PairInfo via KVs. + */ + public static class PairInfo implements PartialGroupByKeyOperation.PairInfo { + private static PairInfo theInstance = new PairInfo(); + public static PairInfo create() { return theInstance; } + private PairInfo() {} + @Override + public Object getKeyFromInputPair(Object pair) { + WindowedValue> windowedKv = (WindowedValue>) pair; + return windowedKv.getValue().getKey(); + } + @Override + public Object getValueFromInputPair(Object pair) { + WindowedValue> windowedKv = (WindowedValue>) pair; + return windowedKv.getValue().getValue(); + } + @Override + public Object makeOutputPair(Object key, Object values) { + return WindowedValue.valueInEmptyWindows(KV.of(key, values)); + } + } + + /** + * Implements PGBKOp.GroupingKeyCreator via Coder. + */ + public static class CoderGroupingKeyCreator + implements PartialGroupByKeyOperation.GroupingKeyCreator { + final Coder coder; + + public CoderGroupingKeyCreator(Coder coder) { + this.coder = coder; + } + + @Override + public Object createGroupingKey(Object value) throws Exception { + return new PartialGroupByKeyOperation.StructuralByteArray( + CoderUtils.encodeToByteArray(coder, value)); + } + } + + /** + * Implements PGBKOp.SizeEstimator via Coder. + */ + public static class CoderSizeEstimator + implements PartialGroupByKeyOperation.SizeEstimator { + final Coder coder; + + public CoderSizeEstimator(Coder coder) { + this.coder = coder; + } + + @Override + public long estimateSize(Object value) throws Exception { + return CoderUtils.encodeToByteArray(coder, value).length; + } + } + + static FlattenOperation createFlattenOperation( + PipelineOptions options, + ParallelInstruction instruction, + ExecutionContext executionContext, + List priorOperations, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + FlattenInstruction flatten = instruction.getFlatten(); + + OutputReceiver[] receivers = + createOutputReceivers(instruction, counterPrefix, addCounterMutator, + stateSampler, 1); + + FlattenOperation operation = + new FlattenOperation(instruction.getSystemName(), receivers, + counterPrefix, addCounterMutator, stateSampler); + + for (InstructionInput input : flatten.getInputs()) { + attachInput(operation, input, priorOperations); + } + + return operation; + } + + /** + * Returns an array of OutputReceivers for the given + * ParallelInstruction definition. + */ + static OutputReceiver[] createOutputReceivers( + ParallelInstruction instruction, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler, + int expectedNumOutputs) + throws Exception { + int numOutputs = 0; + if (instruction.getOutputs() != null) { + numOutputs = instruction.getOutputs().size(); + } + if (numOutputs != expectedNumOutputs) { + throw new AssertionError( + "ParallelInstruction.Outputs has an unexpected length"); + } + OutputReceiver[] receivers = new OutputReceiver[numOutputs]; + for (int i = 0; i < numOutputs; i++) { + InstructionOutput cloudOutput = instruction.getOutputs().get(i); + receivers[i] = new OutputReceiver( + cloudOutput.getName(), + new ElementByteSizeObservableCoder( + Serializer.deserialize(cloudOutput.getCodec(), Coder.class)), + counterPrefix, + addCounterMutator); + } + return receivers; + } + + /** + * Adapts a Coder to the ElementByteSizeObservable interface. + */ + public static class ElementByteSizeObservableCoder + implements ElementByteSizeObservable { + final Coder coder; + + public ElementByteSizeObservableCoder(Coder coder) { + this.coder = coder; + } + + @Override + public boolean isRegisterByteSizeObserverCheap(T value) { + return coder.isRegisterByteSizeObserverCheap(value, Coder.Context.OUTER); + } + + @Override + public void registerByteSizeObserver(T value, + ElementByteSizeObserver observer) + throws Exception { + coder.registerByteSizeObserver(value, observer, Coder.Context.OUTER); + } + } + + /** + * Adds an input to the given Operation, coming from the given + * producer instruction output. + */ + static void attachInput(ReceivingOperation operation, + @Nullable InstructionInput input, + List priorOperations) { + Integer producerInstructionIndex = 0; + Integer outputNum = 0; + if (input != null) { + if (input.getProducerInstructionIndex() != null) { + producerInstructionIndex = input.getProducerInstructionIndex(); + } + if (input.getOutputNum() != null) { + outputNum = input.getOutputNum(); + } + } + // Input id must refer to an operation that has already been seen. + Operation source = priorOperations.get(producerInstructionIndex); + operation.attachInput(source, outputNum); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/NormalParDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/NormalParDoFn.java new file mode 100644 index 000000000000..c6e5f9f163e3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/NormalParDoFn.java @@ -0,0 +1,214 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getBytes; + +import com.google.api.services.dataflow.model.MultiOutputInfo; +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.DoFnRunner.OutputManager; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.ExecutionContext.StepContext; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.OutputReceiver; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoFn; +import com.google.cloud.dataflow.sdk.util.common.worker.Receiver; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A wrapper around a decoded user DoFn. + */ +public class NormalParDoFn extends ParDoFn { + public static NormalParDoFn create( + PipelineOptions options, + CloudObject cloudUserFn, + String stepName, + @Nullable List sideInputInfos, + @Nullable List multiOutputInfos, + Integer numOutputs, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler /* ignored */) + throws Exception { + Object deserializedFn = + SerializableUtils.deserializeFromByteArray( + getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), + "serialized user fn"); + if (!(deserializedFn instanceof DoFn)) { + throw new Exception("unexpected kind of DoFn: " + deserializedFn.getClass().getName()); + } + DoFn fn = (DoFn) deserializedFn; + + PTuple sideInputValues = PTuple.empty(); + if (sideInputInfos != null) { + for (SideInputInfo sideInputInfo : sideInputInfos) { + Object sideInputValue = SideInputUtils.readSideInput( + options, sideInputInfo, executionContext); + TupleTag tag = new TupleTag(sideInputInfo.getTag()); + sideInputValues = sideInputValues.and(tag, sideInputValue); + } + } + + List outputTags = new ArrayList<>(); + if (multiOutputInfos != null) { + for (MultiOutputInfo multiOutputInfo : multiOutputInfos) { + outputTags.add(multiOutputInfo.getTag()); + } + } + if (outputTags.isEmpty()) { + // Legacy support: assume there's a single output tag named "output". + // (The output tag name will be ignored, for the main output.) + outputTags.add("output"); + } + if (numOutputs != outputTags.size()) { + throw new AssertionError( + "unexpected number of outputTags for DoFn"); + } + + return new NormalParDoFn(options, fn, sideInputValues, outputTags, + stepName, executionContext, addCounterMutator); + } + + public final PipelineOptions options; + public final DoFn fn; + public final PTuple sideInputValues; + public final TupleTag mainOutputTag; + public final List> sideOutputTags; + public final String stepName; + public final ExecutionContext executionContext; + private final CounterSet.AddCounterMutator addCounterMutator; + + /** The DoFnRunner executing a batch. Null between batches. */ + DoFnRunner fnRunner; + + public NormalParDoFn(PipelineOptions options, + DoFn fn, + PTuple sideInputValues, + List outputTags, + String stepName, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator) { + this.options = options; + this.fn = fn; + this.sideInputValues = sideInputValues; + if (outputTags.size() < 1) { + throw new AssertionError("expected at least one output"); + } + this.mainOutputTag = new TupleTag(outputTags.get(0)); + this.sideOutputTags = new ArrayList<>(); + if (outputTags.size() > 1) { + for (String tag : outputTags.subList(1, outputTags.size())) { + this.sideOutputTags.add(new TupleTag(tag)); + } + } + this.stepName = stepName; + this.executionContext = executionContext; + this.addCounterMutator = addCounterMutator; + } + + @Override + public void startBundle(final Receiver... receivers) throws Exception { + if (receivers.length != sideOutputTags.size() + 1) { + throw new AssertionError( + "unexpected number of receivers for DoFn"); + } + + StepContext stepContext = null; + if (executionContext != null) { + stepContext = executionContext.getStepContext(stepName); + } + + fnRunner = DoFnRunner.create( + options, + fn, + sideInputValues, + new OutputManager() { + final Map, OutputReceiver> undeclaredOutputs = + new HashMap<>(); + + @Override + public Receiver initialize(TupleTag tag) { + // Declared outputs. + if (tag.equals(mainOutputTag)) { + return receivers[0]; + } else if (sideOutputTags.contains(tag)) { + return receivers[sideOutputTags.indexOf(tag) + 1]; + } + + // Undeclared outputs. + OutputReceiver receiver = undeclaredOutputs.get(tag); + if (receiver == null) { + // A new undeclared output. + // TODO: plumb through the operationName, so that we can + // name implicit outputs after it. + String outputName = "implicit-" + tag.getId(); + // TODO: plumb through the counter prefix, so we can + // make it available to the OutputReceiver class in case + // it wants to use it in naming output counters. (It + // doesn't today.) + String counterPrefix = ""; + receiver = new OutputReceiver( + outputName, counterPrefix, addCounterMutator); + undeclaredOutputs.put(tag, receiver); + } + return receiver; + } + + @Override + public void output(Receiver receiver, WindowedValue output) { + try { + receiver.process(output); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }, + mainOutputTag, + sideOutputTags, + stepContext, + addCounterMutator); + + fnRunner.startBundle(); + } + + @Override + public void processElement(Object elem) throws Exception { + fnRunner.processElement((WindowedValue) elem); + } + + @Override + public void finishBundle() throws Exception { + fnRunner.finishBundle(); + fnRunner = null; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/OrderedCode.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/OrderedCode.java new file mode 100644 index 000000000000..487420ce3934 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/OrderedCode.java @@ -0,0 +1,678 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.common.math.LongMath; +import com.google.common.primitives.Longs; + +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Arrays; + +/** + * This module provides routines for encoding a sequence of typed + * entities into a byte array. The resulting byte arrays can be + * lexicographically compared to yield the same comparison value that + * would have been generated if the encoded items had been compared + * one by one according to their type. + * + * More precisely, suppose: + * 1. byte array A is generated by encoding the sequence of items [A_1..A_n] + * 2. byte array B is generated by encoding the sequence of items [B_1..B_n] + * 3. The types match; i.e., for all i: A_i was encoded using + * the same routine as B_i + * Then: + * Comparing A vs. B lexicographically is the same as comparing + * the vectors [A_1..A_n] and [B_1..B_n] lexicographically. + * + *

+ * This class is NOT thread safe. + */ +public class OrderedCode { + // We want to encode a few extra symbols in strings: + // Separator between items + // Infinite string + // + // Therefore we need an alphabet with at least 258 characters. We + // achieve this by using two-letter sequences starting with '\0' and '\xff' + // as extra symbols: + // encoded as => \0\1 + // \0 encoded as => \0\xff + // \xff encoded as => \xff\x00 + // encoded as => \xff\xff + // + // The remaining two letter sequences starting with '\0' and '\xff' + // are currently unused. + + public static final byte ESCAPE1 = 0x00; + public static final byte NULL_CHARACTER = + (byte) 0xff; // Combined with ESCAPE1 + public static final byte SEPARATOR = 0x01; // Combined with ESCAPE1 + + public static final byte ESCAPE2 = (byte) 0xff; + public static final byte INFINITY = + (byte) 0xff; // Combined with ESCAPE2 + public static final byte FF_CHARACTER = 0x00; // Combined with ESCAPE2 + + public static final byte[] ESCAPE1_SEPARATOR = { ESCAPE1, SEPARATOR }; + + public static final byte[] INFINITY_ENCODED = { ESCAPE2, INFINITY }; + + /** + * This array maps encoding length to header bits in the first two bytes for + * SignedNumIncreasing encoding. + */ + private static final byte[][] LENGTH_TO_HEADER_BITS = { + { 0, 0 }, + { (byte) 0x80, 0 }, + { (byte) 0xc0, 0 }, + { (byte) 0xe0, 0 }, + { (byte) 0xf0, 0 }, + { (byte) 0xf8, 0 }, + { (byte) 0xfc, 0 }, + { (byte) 0xfe, 0 }, + { (byte) 0xff, 0 }, + { (byte) 0xff, (byte) 0x80 }, + { (byte) 0xff, (byte) 0xc0 } + }; + + /** + * This array maps encoding lengths to the header bits that overlap with + * the payload and need fixing during readSignedNumIncreasing. + */ + private static final long[] LENGTH_TO_MASK = { + 0L, + 0x80L, + 0xc000L, + 0xe00000L, + 0xf0000000L, + 0xf800000000L, + 0xfc0000000000L, + 0xfe000000000000L, + 0xff00000000000000L, + 0x8000000000000000L, + 0L + }; + + /** + * This array maps the number of bits in a number to the encoding + * length produced by WriteSignedNumIncreasing. + * For positive numbers, the number of bits is 1 plus the most significant + * bit position (the highest bit position in a positive long is 63). + * For a negative number n, we count the bits in ~n. + * That is, length = BITS_TO_LENGTH[log2Floor(n < 0 ? ~n : n) + 1]. + */ + private static final short[] BITS_TO_LENGTH = { + 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, + 10 + }; + + /** + * stores the current encoded value as a list of byte arrays. Note that this + * is manipulated as we read/write items. + * Note that every item will fit on at most one array. One array may + * have more than one item (eg when used for decoding). While encoding, + * one array will have exactly one item. While returning the encoded array + * we will merge all the arrays in this list. + */ + private final ArrayList encodedArrays = new ArrayList<>(); + + /** + * This is the current position on the first array. Will be non-zero + * only if the ordered code was created using encoded byte array. + */ + private int firstArrayPosition = 0; + + /** + * Creates OrderedCode from scractch. Typically used at encoding time. + */ + public OrderedCode(){ + } + + /** + * Creates OrderedCode from a given encoded byte array. Typically used at + * decoding time. + * + *

+ * For better performance, it uses the input array provided (not a copy). + * Therefore the input array should not be modified. + */ + public OrderedCode(byte[] encodedByteArray) { + encodedArrays.add(encodedByteArray); + } + + /** + * Adds the given byte array item to the OrderedCode. It encodes the input + * byte array, followed by a separator and appends the result to its + * internal encoded byte array store. + * + *

+ * It works with the input array, + * so the input array 'value' should not be modified till the method returns. + * + * @param value bytes to be written. + * @see #readBytes() + */ + public void writeBytes(byte[] value) { + // Determine the length of the encoded array + int encodedLength = 2; // for separator + for (byte b : value) { + if ((b == ESCAPE1) || (b == ESCAPE2)) { + encodedLength += 2; + } else { + encodedLength++; + } + } + + byte[] encodedArray = new byte[encodedLength]; + int copyStart = 0; + int outIndex = 0; + for (int i = 0; i < value.length; i++) { + byte b = value[i]; + if (b == ESCAPE1) { + System.arraycopy(value, copyStart, encodedArray, outIndex, + i - copyStart); + outIndex += i - copyStart; + encodedArray[outIndex++] = ESCAPE1; + encodedArray[outIndex++] = NULL_CHARACTER; + copyStart = i + 1; + } else if (b == ESCAPE2) { + System.arraycopy(value, copyStart, encodedArray, outIndex, + i - copyStart); + outIndex += i - copyStart; + encodedArray[outIndex++] = ESCAPE2; + encodedArray[outIndex++] = FF_CHARACTER; + copyStart = i + 1; + } + } + if (copyStart < value.length) { + System.arraycopy(value, copyStart, encodedArray, outIndex, + value.length - copyStart); + outIndex += value.length - copyStart; + } + encodedArray[outIndex++] = ESCAPE1; + encodedArray[outIndex] = SEPARATOR; + + encodedArrays.add(encodedArray); + } + + /** + * Encodes the long item, in big-endian format, and appends the result to its + * internal encoded byte array store. + * + * Note that the specified long is treated like a uint64, e.g. + * {@code new OrderedCode().writeNumIncreasing(-1L).getEncodedBytes() > + * new OrderedCode().writeNumIncreasing(Long.MAX_VALUE).getEncodedBytes()}. + * + * @see #readNumIncreasing() + */ + public void writeNumIncreasing(long value) { + // Values are encoded with a single byte length prefix, followed + // by the actual value in big-endian format with leading 0 bytes + // dropped. + byte[] bufer = new byte[9]; // 8 bytes for value plus one byte for length + int len = 0; + while (value != 0) { + len++; + bufer[9 - len] = (byte) (value & 0xff); + value >>>= 8; + } + bufer[9 - len - 1] = (byte) len; + len++; + byte[] encodedArray = new byte[len]; + System.arraycopy(bufer, 9 - len, encodedArray, 0, len); + encodedArrays.add(encodedArray); + } + + /** + * Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. + */ + int log2Floor(long n) { + if (n < 0) { + throw new IllegalArgumentException("must be non-negative"); + } + return n == 0 ? -1 : LongMath.log2(n, RoundingMode.FLOOR); + } + + /** + * Calculates the encoding length in bytes of the signed number n. + */ + int getSignedEncodingLength(long n) { + return BITS_TO_LENGTH[log2Floor(n < 0 ? ~n : n) + 1]; + } + + /** + * Encodes the long item, in big-endian format, and appends the result to its + * internal encoded byte array store. + * + * Note that the specified long is treated like an int64, i.e. + * {@code new OrderedCode().writeNumIncreasing(-1L).getEncodedBytes() < + * new OrderedCode().writeNumIncreasing(0L).getEncodedBytes()}. + * + * @see #readSignedNumIncreasing() + */ + public void writeSignedNumIncreasing(long val) { + long x = val < 0 ? ~val : val; + if (x < 64) { // Fast path for encoding length == 1. + byte[] encodedArray = + new byte[] { (byte) (LENGTH_TO_HEADER_BITS[1][0] ^ val) }; + encodedArrays.add(encodedArray); + return; + } + // buf = val in network byte order, sign extended to 10 bytes. + byte signByte = val < 0 ? (byte) 0xff : 0; + byte[] buf = new byte[2 + Longs.BYTES]; + buf[0] = buf[1] = signByte; + System.arraycopy(Longs.toByteArray(val), 0, buf, 2, Longs.BYTES); + int len = getSignedEncodingLength(x); + if (len < 2) { + throw new IllegalStateException( + "Invalid length (" + len + ")" + + " returned by getSignedEncodingLength(" + x + ")"); + } + int beginIndex = buf.length - len; + buf[beginIndex] ^= LENGTH_TO_HEADER_BITS[len][0]; + buf[beginIndex + 1] ^= LENGTH_TO_HEADER_BITS[len][1]; + + byte[] encodedArray = new byte[len]; + System.arraycopy(buf, beginIndex, encodedArray, 0, len); + encodedArrays.add(encodedArray); + } + + /** + * Encodes and appends INFINITY item to its internal encoded byte array + * store. + * + * @see #readInfinity() + */ + public void writeInfinity() { + writeTrailingBytes(INFINITY_ENCODED); + } + + /** + * Appends the byte array item to its internal encoded byte array + * store. This is used for the last item and is not encoded. It + * also can be used to write a fixed number of bytes which will be + * read back using {@link #readBytes(int)}. + * + *

+ * It stores the input array in the store, + * so the input array 'value' should not be modified. + * + * @param value bytes to be written. + * @see #readTrailingBytes() + * @see #readBytes(int) + */ + public void writeTrailingBytes(byte[] value) { + if ((value == null) || (value.length == 0)) { + throw new IllegalArgumentException( + "Value cannot be null or have 0 elements"); + } + + encodedArrays.add(value); + } + + /** + * Returns the next byte array item from its encoded byte array store and + * removes the item from the store. + * + * @see #writeBytes(byte[]) + */ + public byte[] readBytes() { + if ((encodedArrays == null) || (encodedArrays.size() == 0) || + ((encodedArrays.get(0)).length - firstArrayPosition <= 0)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + // Determine the length of the decoded array + // We only scan up to "length-2" since a valid string must end with + // a two character terminator: 'ESCAPE1 SEPARATOR' + byte[] store = encodedArrays.get(0); + int decodedLength = 0; + boolean valid = false; + int i = firstArrayPosition; + while (i < store.length - 1) { + byte b = store[i++]; + if (b == ESCAPE1) { + b = store[i++]; + if (b == SEPARATOR) { + valid = true; + break; + } else if (b == NULL_CHARACTER) { + decodedLength++; + } else { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + } else if (b == ESCAPE2) { + b = store[i++]; + if (b == FF_CHARACTER) { + decodedLength++; + } else { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + } else { + decodedLength++; + } + } + if (!valid) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + byte[] decodedArray = new byte[decodedLength]; + int copyStart = firstArrayPosition; + int outIndex = 0; + int j = firstArrayPosition; + while (j < store.length - 1) { + byte b = store[j++]; // note that j has been incremented + if (b == ESCAPE1) { + System.arraycopy(store, copyStart, decodedArray, outIndex, + j - copyStart - 1); + outIndex += j - copyStart - 1; + // ESCAPE1 SEPARATOR ends component + // ESCAPE1 NULL_CHARACTER represents '\0' + b = store[j++]; + if (b == SEPARATOR) { + if ((store.length - j) == 0) { + // we are done with the first array + encodedArrays.remove(0); + firstArrayPosition = 0; + } else { + firstArrayPosition = j; + } + return decodedArray; + } else if (b == NULL_CHARACTER) { + decodedArray[outIndex++] = 0x00; + } // else not required - handled during length determination + copyStart = j; + } else if (b == ESCAPE2) { + System.arraycopy(store, copyStart, decodedArray, outIndex, + j - copyStart - 1); + outIndex += j - copyStart - 1; + // ESCAPE2 FF_CHARACTER represents '\xff' + // ESCAPE2 INFINITY is an error + b = store[j++]; + if (b == FF_CHARACTER) { + decodedArray[outIndex++] = (byte) 0xff; + } // else not required - handled during length determination + copyStart = j; + } + } + // not required due to the first phase, but need to entertain the compiler + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + /** + * Returns the next long item (encoded in big-endian format via + * {@code writeNumIncreasing(long)}) from its internal encoded byte array + * store and removes the item from the store. + * + * @see #writeNumIncreasing(long) + */ + public long readNumIncreasing() { + if ((encodedArrays == null) || (encodedArrays.size() == 0) || + ((encodedArrays.get(0)).length - firstArrayPosition < 1)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + byte[] store = encodedArrays.get(0); + // Decode length byte + int len = store[firstArrayPosition]; + if ((firstArrayPosition + len + 1 > store.length) || len > 8) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + long result = 0; + for (int i = 0; i < len; i++) { + result <<= 8; + result |= (store[firstArrayPosition + i + 1] & 0xff); + } + + if ((store.length - firstArrayPosition - len - 1) == 0) { + // we are done with the first array + encodedArrays.remove(0); + firstArrayPosition = 0; + } else { + firstArrayPosition = firstArrayPosition + len + 1; + } + + return result; + } + + /** + * Returns the next long item (encoded via + * {@code writeSignedNumIncreasing(long)}) from its internal encoded byte + * array store and removes the item from the store. + * + * @see #writeSignedNumIncreasing(long) + */ + public long readSignedNumIncreasing() { + if ((encodedArrays == null) || (encodedArrays.size() == 0) || + ((encodedArrays.get(0)).length - firstArrayPosition < 1)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + byte[] store = encodedArrays.get(0); + + long xorMask = ((store[firstArrayPosition] & 0x80) == 0) ? ~0L : 0L; + // Store first byte as an int rather than a (signed) byte -- to avoid + // accidental byte-to-int promotion later which would extend the byte's + // sign bit (if any). + int firstByte = + (store[firstArrayPosition] & 0xff) ^ (int) (xorMask & 0xff); + + // Now calculate and test length, and set x to raw (unmasked) result. + int len; + long x; + if (firstByte != 0xff) { + len = 7 - log2Floor(firstByte ^ 0xff); + if (store.length - firstArrayPosition < len) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + x = xorMask; // Sign extend using xorMask. + for (int i = firstArrayPosition; i < firstArrayPosition + len; i++) { + x = (x << 8) | (store[i] & 0xff); + } + } else { + len = 8; + if (store.length - firstArrayPosition < len) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + int secondByte = + (store[firstArrayPosition + 1] & 0xff) ^ (int) (xorMask & 0xff); + if (secondByte >= 0x80) { + if (secondByte < 0xc0) { + len = 9; + } else { + int thirdByte = + (store[firstArrayPosition + 2] & 0xff) ^ (int) (xorMask & 0xff); + if (secondByte == 0xc0 && thirdByte < 0x80) { + len = 10; + } else { + // Either len > 10 or len == 10 and #bits > 63. + throw new IllegalArgumentException("Invalid encoded byte array"); + } + } + if (store.length - firstArrayPosition < len) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + } + x = Longs.fromByteArray(Arrays.copyOfRange( + store, firstArrayPosition + len - 8, firstArrayPosition + len)); + } + + x ^= LENGTH_TO_MASK[len]; // Remove spurious header bits. + + if (len != getSignedEncodingLength(x)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + if ((store.length - firstArrayPosition - len) == 0) { + // We are done with the first array. + encodedArrays.remove(0); + firstArrayPosition = 0; + } else { + firstArrayPosition = firstArrayPosition + len; + } + + return x; + } + + /** + * Removes INFINITY item from its internal encoded byte array store + * if present. Returns whether INFINITY was present. + * + * @see #writeInfinity() + */ + public boolean readInfinity() { + if ((encodedArrays == null) || (encodedArrays.size() == 0) || + ((encodedArrays.get(0)).length - firstArrayPosition < 1)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + byte[] store = encodedArrays.get(0); + if (store.length - firstArrayPosition < 2) { + return false; + } + if ((store[firstArrayPosition] == ESCAPE2) && + (store[firstArrayPosition + 1] == INFINITY)) { + if ((store.length - firstArrayPosition - 2) == 0) { + // we are done with the first array + encodedArrays.remove(0); + firstArrayPosition = 0; + } else { + firstArrayPosition = firstArrayPosition + 2; + } + return true; + } else { + return false; + } + } + + /** + * Returns the trailing byte array item from its internal encoded byte array + * store and removes the item from the store. + * + * @see #writeTrailingBytes(byte[]) + */ + public byte[] readTrailingBytes() { + // one item is contained within one byte array + if ((encodedArrays == null) || (encodedArrays.size() != 1)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + byte[] store = encodedArrays.get(0); + encodedArrays.remove(0); + assert encodedArrays.size() == 0; + return Arrays.copyOfRange(store, firstArrayPosition, store.length); + } + + /** + * Reads (unencoded) {@code len} bytes. + * + * @see #writeTrailingBytes(byte[]) + */ + public byte[] readBytes(int len) { + if ((encodedArrays == null) || (encodedArrays.size() == 0) || + ((encodedArrays.get(0)).length - firstArrayPosition < len)) { + throw new IllegalArgumentException("Invalid encoded byte array"); + } + + byte[] store = encodedArrays.get(0); + + byte[] result; + if (store.length - firstArrayPosition == len) { + // We are done with the first array. + result = encodedArrays.remove(0); + firstArrayPosition = 0; + } else { + result = new byte[len]; + System.arraycopy(store, firstArrayPosition, result, 0, len); + firstArrayPosition = firstArrayPosition + len; + } + return result; + } + + /** + * Returns the encoded bytes that represent the current state of the + * OrderedCode. + * + *

+ * NOTE: This method returns OrederedCode's internal array (not a + * copy) for better performance. Therefore the returned array should not be + * modified. + */ + public byte[] getEncodedBytes() { + if (encodedArrays.size() == 0) { + return new byte[0]; + } + if ((encodedArrays.size() == 1) && (firstArrayPosition == 0)) { + return encodedArrays.get(0); + } + + int totalLength = 0; + + for (int i = 0; i < encodedArrays.size(); i++) { + byte[] bytes = encodedArrays.get(i); + if (i == 0) { + totalLength += bytes.length - firstArrayPosition; + } else { + totalLength += bytes.length; + } + } + + byte[] encodedBytes = new byte[totalLength]; + int destPos = 0; + for (int i = 0; i < encodedArrays.size(); i++) { + byte[] bytes = encodedArrays.get(i); + if (i == 0) { + System.arraycopy(bytes, firstArrayPosition, encodedBytes, destPos, + bytes.length - firstArrayPosition); + destPos += bytes.length - firstArrayPosition; + } else { + System.arraycopy(bytes, 0, encodedBytes, destPos, bytes.length); + destPos += bytes.length; + } + } + + // replace the store with merged array, so that repeated calls + // don't need to merge. The reads can handle both the versions. + encodedArrays.clear(); + encodedArrays.add(encodedBytes); + firstArrayPosition = 0; + + return encodedBytes; + } + + /** + * Returns true if this has more encoded bytes that haven't been read, + * false otherwise. Return value of true doesn't imply anything about + * validity of remaining data. + * @return true if it has more encoded bytes that haven't been read, + * false otherwise. + */ + public boolean hasRemainingEncodedBytes() { + // We delete an array after fully consuming it. + return encodedArrays != null && encodedArrays.size() != 0; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ParDoFnFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ParDoFnFactory.java new file mode 100644 index 000000000000..23d4040685bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ParDoFnFactory.java @@ -0,0 +1,115 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.api.services.dataflow.model.MultiOutputInfo; +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoFn; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Creates a ParDoFn from a CloudObject spec. + * + * A ParDoFnFactory concrete "subclass" should define a method with + * the following signature: + *

 {@code
+ * static SomeParDoFnSubclass create(
+ *     CloudObject spec,
+ *     List sideInputInfos,
+ *     List multiOutputInfos,
+ *     int numOutputs,
+ *     ExecutionContext executionContext);
+ * } 
+ */ +public class ParDoFnFactory { + // Do not instantiate. + private ParDoFnFactory() {} + + /** + * A map from the short names of predefined ParDoFnFactories to their full + * class names. + */ + static Map predefinedParDoFnFactories = new HashMap<>(); + + static { + predefinedParDoFnFactories.put("DoFn", + NormalParDoFn.class.getName()); + predefinedParDoFnFactories.put("CombineValuesFn", + CombineValuesFn.class.getName()); + // TODO: Remove outdated bindings once the services produces the right ones + predefinedParDoFnFactories.put("MergeBucketsDoFn", + GroupAlsoByWindowsParDoFn.class.getName()); + predefinedParDoFnFactories.put("AssignBucketsDoFn", + AssignWindowsParDoFn.class.getName()); + predefinedParDoFnFactories.put("MergeWindowsDoFn", + GroupAlsoByWindowsParDoFn.class.getName()); + predefinedParDoFnFactories.put("AssignWindowsDoFn", + AssignWindowsParDoFn.class.getName()); + } + + /** + * Creates a ParDoFn from a CloudObject spec. + * + * @throws Exception if the CloudObject spec could not be + * decoded and constructed. + */ + public static ParDoFn create(PipelineOptions options, + CloudObject cloudUserFn, + String stepName, + List sideInputInfos, + List multiOutputInfos, + int numOutputs, + ExecutionContext executionContext, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) + throws Exception { + String className = cloudUserFn.getClassName(); + String parDoFnFactoryClassName = predefinedParDoFnFactories.get(className); + if (parDoFnFactoryClassName == null) { + parDoFnFactoryClassName = className; + } + + try { + return InstanceBuilder.ofType(ParDoFn.class) + .fromClassName(parDoFnFactoryClassName) + .fromFactoryMethod("create") + .withArg(PipelineOptions.class, options) + .withArg(CloudObject.class, cloudUserFn) + .withArg(String.class, stepName) + .withArg(List.class, sideInputInfos) + .withArg(List.class, multiOutputInfos) + .withArg(Integer.class, numOutputs) + .withArg(ExecutionContext.class, executionContext) + .withArg(CounterSet.AddCounterMutator.class, addCounterMutator) + .withArg(StateSampler.class, stateSampler) + .build(); + + } catch (ClassNotFoundException exn) { + throw new Exception( + "unable to create a ParDoFn from " + cloudUserFn, exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSource.java new file mode 100644 index 000000000000..5394a26cc47f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSource.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.BatchingShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.io.IOException; +import java.util.Iterator; + +/** + * A source that reads from a key-sharded dataset, and returns KVs without + * any values grouping. + * + * @param the type of the keys read from the shuffle + * @param the type of the values read from the shuffle + */ +public class PartitioningShuffleSource extends Source>> { + + final byte[] shuffleReaderConfig; + final String startShufflePosition; + final String stopShufflePosition; + Coder keyCoder; + WindowedValueCoder windowedValueCoder; + + public PartitioningShuffleSource(PipelineOptions options, + byte[] shuffleReaderConfig, + String startShufflePosition, + String stopShufflePosition, + Coder>> coder) + throws Exception { + this.shuffleReaderConfig = shuffleReaderConfig; + this.startShufflePosition = startShufflePosition; + this.stopShufflePosition = stopShufflePosition; + initCoder(coder); + } + + /** + * Given a {@code WindowedValueCoder>}, splits it into a coder for K + * and a {@code WindowedValueCoder} with the same kind of windows. + */ + private void initCoder(Coder>> coder) throws Exception { + if (!(coder instanceof WindowedValueCoder)) { + throw new Exception( + "unexpected kind of coder for WindowedValue: " + coder); + } + WindowedValueCoder> windowedElemCoder = ((WindowedValueCoder>) coder); + Coder> elemCoder = windowedElemCoder.getValueCoder(); + if (!(elemCoder instanceof KvCoder)) { + throw new Exception( + "unexpected kind of coder for elements read from " + + "a key-partitioning shuffle: " + elemCoder); + } + KvCoder kvCoder = (KvCoder) elemCoder; + this.keyCoder = kvCoder.getKeyCoder(); + windowedValueCoder = windowedElemCoder.withValueCoder(kvCoder.getValueCoder()); + } + + @Override + public com.google.cloud.dataflow.sdk.util.common.worker.Source.SourceIterator< + WindowedValue>> iterator() throws IOException { + Preconditions.checkArgument(shuffleReaderConfig != null); + return iterator(new BatchingShuffleEntryReader( + new ChunkingShuffleBatchReader(new ApplianceShuffleReader( + shuffleReaderConfig)))); + } + + SourceIterator>> iterator(ShuffleEntryReader reader) throws IOException { + return new PartitioningShuffleSourceIterator(reader); + } + + /** + * A SourceIterator that reads from a ShuffleEntryReader, + * extracts K and {@code WindowedValue}, and returns a constructed + * {@code WindowedValue}. + */ + class PartitioningShuffleSourceIterator + extends AbstractSourceIterator>> { + Iterator iterator; + + PartitioningShuffleSourceIterator(ShuffleEntryReader reader) { + this.iterator = reader.read( + ByteArrayShufflePosition.fromBase64(startShufflePosition), + ByteArrayShufflePosition.fromBase64(stopShufflePosition)); + } + + @Override + public boolean hasNext() throws IOException { + return iterator.hasNext(); + } + + @Override + public WindowedValue> next() throws IOException { + ShuffleEntry record = iterator.next(); + K key = CoderUtils.decodeFromByteArray(keyCoder, record.getKey()); + WindowedValue windowedValue = + CoderUtils.decodeFromByteArray(windowedValueCoder, record.getValue()); + notifyElementRead(record.length()); + return WindowedValue.of(KV.of(key, windowedValue.getValue()), + windowedValue.getTimestamp(), + windowedValue.getWindows()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSourceFactory.java new file mode 100644 index 000000000000..f97d1d5b8298 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSourceFactory.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.decodeBase64; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; + +/** + * Creates a PartitioningShuffleSource from a CloudObject spec. + */ +public class PartitioningShuffleSourceFactory { + // Do not instantiate. + private PartitioningShuffleSourceFactory() {} + + public static PartitioningShuffleSource create( + PipelineOptions options, + CloudObject spec, + Coder>> coder, + ExecutionContext executionContext) + throws Exception { + return new PartitioningShuffleSource( + options, + decodeBase64(getString(spec, PropertyNames.SHUFFLE_READER_CONFIG)), + getString(spec, PropertyNames.START_SHUFFLE_POSITION, null), + getString(spec, PropertyNames.END_SHUFFLE_POSITION, null), + coder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleEntryWriter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleEntryWriter.java new file mode 100644 index 000000000000..4fd44230421d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleEntryWriter.java @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; + +import java.io.IOException; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * ShuffleEntryWriter provides an interface for writing key/value + * entries to a shuffle dataset. + */ +@NotThreadSafe +interface ShuffleEntryWriter extends AutoCloseable { + /** + * Writes an entry to a shuffle dataset. Returns the size + * in bytes of the data written. + */ + public long put(ShuffleEntry entry) throws IOException; + + @Override + public void close() throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleLibrary.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleLibrary.java new file mode 100644 index 000000000000..8863436d2e1d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleLibrary.java @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; + +/** + * Native library used to read from and write to a shuffle dataset. + */ +class ShuffleLibrary { + /** + * Loads the native shuffle library. + */ + static void load() { + try { + File tempfile = File.createTempFile("libshuffle_client_jni", ".so"); + InputStream input = ClassLoader.getSystemResourceAsStream( + "libshuffle_client_jni.so.stripped"); + Files.copy(input, tempfile.toPath(), StandardCopyOption.REPLACE_EXISTING); + System.load(tempfile.getAbsolutePath()); + } catch (IOException e) { + throw new RuntimeException("Loading shuffle_client failed:", e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleReader.java new file mode 100644 index 000000000000..8a1018b237ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleReader.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import java.io.IOException; + +/** + * ShuffleReader reads chunks of data from a shuffle dataset for + * a given position range. + */ +interface ShuffleReader { + /** Represents a chunk of data read from a shuffle dataset. */ + public static class ReadChunkResult { + public final byte[] chunk; + public final byte[] nextStartPosition; + public ReadChunkResult(byte[] chunk, byte[] nextStartPosition) { + this.chunk = chunk; + this.nextStartPosition = nextStartPosition; + } + } + + /** + * Reads a chunk of data for keys in the given position range. + * The chunk is a sequence of pairs encoded as: + * {@code + } + * where the sizes are 4-byte big-endian integers. + * + * @param startPosition the start of the requested range (inclusive) + * @param endPosition the end of the requested range (exclusive) + */ + public ReadChunkResult readIncludingPosition( + byte[] startPosition, byte[] endPosition) throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSink.java new file mode 100644 index 000000000000..72ea16fc99b4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSink.java @@ -0,0 +1,248 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.io.IOException; + +/** + * A sink that writes to a shuffle dataset. + * + * @param the type of the elements written to the sink + */ +public class ShuffleSink extends Sink> { + + enum ShuffleKind { UNGROUPED, PARTITION_KEYS, GROUP_KEYS, GROUP_KEYS_AND_SORT_VALUES } + + static final long SHUFFLE_WRITER_BUFFER_SIZE = 128 << 20; + + final byte[] shuffleWriterConfig; + + final ShuffleKind shuffleKind; + + boolean shardByKey; + boolean groupValues; + boolean sortValues; + + WindowedValueCoder windowedElemCoder; + WindowedValueCoder windowedValueCoder; + Coder elemCoder; + Coder keyCoder; + Coder valueCoder; + Coder sortKeyCoder; + Coder sortValueCoder; + + public static ShuffleKind parseShuffleKind(String shuffleKind) + throws Exception { + try { + return Enum.valueOf(ShuffleKind.class, shuffleKind.trim().toUpperCase()); + } catch (IllegalArgumentException e) { + throw new Exception("unexpected shuffle_kind", e); + } + } + + public ShuffleSink(PipelineOptions options, + byte[] shuffleWriterConfig, + ShuffleKind shuffleKind, + Coder> coder) + throws Exception { + this.shuffleWriterConfig = shuffleWriterConfig; + this.shuffleKind = shuffleKind; + initCoder(coder); + } + + private void initCoder(Coder> coder) throws Exception { + switch (shuffleKind) { + case UNGROUPED: + this.shardByKey = false; + this.groupValues = false; + this.sortValues = false; + break; + case PARTITION_KEYS: + this.shardByKey = true; + this.groupValues = false; + this.sortValues = false; + break; + case GROUP_KEYS: + this.shardByKey = true; + this.groupValues = true; + this.sortValues = false; + break; + case GROUP_KEYS_AND_SORT_VALUES: + this.shardByKey = true; + this.groupValues = true; + this.sortValues = true; + break; + default: + throw new AssertionError("unexpected shuffle kind"); + } + + this.windowedElemCoder = (WindowedValueCoder) coder; + this.elemCoder = windowedElemCoder.getValueCoder(); + if (shardByKey) { + if (!(elemCoder instanceof KvCoder)) { + throw new Exception( + "unexpected kind of coder for elements written to " + + "a key-grouping shuffle"); + } + KvCoder kvCoder = (KvCoder) elemCoder; + this.keyCoder = kvCoder.getKeyCoder(); + this.valueCoder = kvCoder.getValueCoder(); + if (sortValues) { + // TODO: Decide the representation of sort-keyed values. + // For now, we'll just use KVs. + if (!(valueCoder instanceof KvCoder)) { + throw new Exception( + "unexpected kind of coder for values written to " + + "a value-sorting shuffle"); + } + KvCoder kvValueCoder = (KvCoder) valueCoder; + this.sortKeyCoder = kvValueCoder.getKeyCoder(); + this.sortValueCoder = kvValueCoder.getValueCoder(); + } else { + this.sortKeyCoder = null; + this.sortValueCoder = null; + } + if (groupValues) { + this.windowedValueCoder = null; + } else { + this.windowedValueCoder = this.windowedElemCoder.withValueCoder(this.valueCoder); + } + } else { + this.keyCoder = null; + this.valueCoder = null; + this.sortKeyCoder = null; + this.sortValueCoder = null; + this.windowedValueCoder = null; + } + } + + /** + * Returns a SinkWriter that allows writing to this ShuffleSink, + * using the given ShuffleEntryWriter. + */ + public SinkWriter> writer(ShuffleEntryWriter writer) throws IOException { + return new ShuffleSinkWriter(writer); + } + + /** The SinkWriter for a ShuffleSink. */ + class ShuffleSinkWriter implements SinkWriter> { + ShuffleEntryWriter writer; + long seqNum = 0; + + ShuffleSinkWriter(ShuffleEntryWriter writer) throws IOException { + this.writer = writer; + } + + @Override + public long add(WindowedValue windowedElem) throws IOException { + byte[] keyBytes; + byte[] secondaryKeyBytes; + byte[] valueBytes; + T elem = windowedElem.getValue(); + if (shardByKey) { + if (!(elem instanceof KV)) { + throw new AssertionError( + "expecting the values written to a key-grouping shuffle " + + "to be KVs"); + } + KV kv = (KV) elem; + Object key = kv.getKey(); + Object value = kv.getValue(); + + keyBytes = CoderUtils.encodeToByteArray(keyCoder, key); + + if (sortValues) { + if (!(value instanceof KV)) { + throw new AssertionError( + "expecting the value parts of the KVs written to " + + "a value-sorting shuffle to also be KVs"); + } + KV kvValue = (KV) value; + Object sortKey = kvValue.getKey(); + Object sortValue = kvValue.getValue(); + + // TODO: Need to coordinate with the + // GroupingShuffleSource, to make sure it knows how to + // reconstruct the value from the sortKeyBytes and + // sortValueBytes. Right now, it doesn't know between + // sorting and non-sorting GBKs. + secondaryKeyBytes = + CoderUtils.encodeToByteArray(sortKeyCoder, sortKey); + valueBytes = CoderUtils.encodeToByteArray(sortValueCoder, sortValue); + + } else if (groupValues) { + // Sort values by timestamp so that GroupAlsoByWindows can run efficiently. + if (windowedElem.getTimestamp().getMillis() == Long.MIN_VALUE) { + // Empty secondary keys sort before all other secondary keys, so we + // can omit this common value here for efficiency. + secondaryKeyBytes = null; + } else { + secondaryKeyBytes = + CoderUtils.encodeToByteArray(InstantCoder.of(), windowedElem.getTimestamp()); + } + valueBytes = CoderUtils.encodeToByteArray(valueCoder, value); + } else { + secondaryKeyBytes = null; + valueBytes = CoderUtils.encodeToByteArray( + windowedValueCoder, + WindowedValue.of(value, windowedElem.getTimestamp(), windowedElem.getWindows())); + } + + } else { + // Not partitioning or grouping by key, just resharding values. + // is ignored, except by the shuffle splitter. Use a seq# + // as the key, so we can split records anywhere. This also works + // for writing a single-sharded ordered PCollection through a + // shuffle, since the order of elements in the input will be + // preserved in the output. + keyBytes = + CoderUtils.encodeToByteArray(BigEndianLongCoder.of(), seqNum++); + + secondaryKeyBytes = null; + valueBytes = CoderUtils.encodeToByteArray(windowedElemCoder, windowedElem); + } + + return writer.put(new ShuffleEntry( + keyBytes, secondaryKeyBytes, valueBytes)); + } + + @Override + public void close() throws IOException { + writer.close(); + } + } + + @Override + public SinkWriter> writer() throws IOException { + Preconditions.checkArgument(shuffleWriterConfig != null); + return writer(new ChunkingShuffleEntryWriter(new ApplianceShuffleWriter( + shuffleWriterConfig, SHUFFLE_WRITER_BUFFER_SIZE))); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkFactory.java new file mode 100644 index 000000000000..6db9945eb613 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.decodeBase64; +import static com.google.cloud.dataflow.sdk.runners.worker.ShuffleSink.parseShuffleKind; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +/** + * Creates a ShuffleSink from a CloudObject spec. + */ +public class ShuffleSinkFactory { + // Do not instantiate. + private ShuffleSinkFactory() {} + + public static ShuffleSink create(PipelineOptions options, + CloudObject spec, + Coder> coder, + ExecutionContext executionContext) + throws Exception { + return create(options, spec, coder); + } + + static ShuffleSink create(PipelineOptions options, + CloudObject spec, + Coder> coder) + throws Exception { + return new ShuffleSink<>( + options, + decodeBase64(getString(spec, PropertyNames.SHUFFLE_WRITER_CONFIG, null)), + parseShuffleKind(getString(spec, PropertyNames.SHUFFLE_KIND)), + coder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleWriter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleWriter.java new file mode 100644 index 000000000000..ff880fd13c4c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleWriter.java @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import java.io.IOException; + +/** + * ShuffleWriter writes chunks of records to a shuffle dataset. + */ +interface ShuffleWriter extends AutoCloseable { + /** + * Writes a chunk of records. The chunk is a sequence of pairs encoded as: + * + * where the sizes are 4-byte big-endian integers. + */ + public void write(byte[] chunk) throws IOException; + + /** + * Flushes written records and closes this writer. + */ + @Override + public void close() throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SideInputUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SideInputUtils.java new file mode 100644 index 000000000000..f3fc1cf3f3ef --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SideInputUtils.java @@ -0,0 +1,211 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * Utilities for working with side inputs. + */ +public class SideInputUtils { + static final String SINGLETON_KIND = "singleton"; + static final String COLLECTION_KIND = "collection"; + + /** + * Reads the given side input, producing the contents associated + * with a a {@link PCollectionView}. + */ + public static Object readSideInput(PipelineOptions options, + SideInputInfo sideInputInfo, + ExecutionContext executionContext) + throws Exception { + Iterable elements = + readSideInputSources(options, sideInputInfo.getSources(), executionContext); + return readSideInputValue(sideInputInfo.getKind(), elements); + } + + static Iterable readSideInputSources( + PipelineOptions options, + List sideInputSources, + ExecutionContext executionContext) + throws Exception { + int numSideInputSources = sideInputSources.size(); + if (numSideInputSources == 0) { + throw new Exception("expecting at least one side input Source"); + } else if (numSideInputSources == 1) { + return readSideInputSource(options, sideInputSources.get(0), executionContext); + } else { + List> shards = new ArrayList<>(); + for (com.google.api.services.dataflow.model.Source sideInputSource + : sideInputSources) { + shards.add(readSideInputSource(options, sideInputSource, executionContext)); + } + return new ShardedIterable<>(shards); + } + } + + static Iterable readSideInputSource( + PipelineOptions options, + com.google.api.services.dataflow.model.Source sideInputSource, + ExecutionContext executionContext) + throws Exception { + return new SourceIterable<>( + SourceFactory.create(options, sideInputSource, executionContext)); + } + + static Object readSideInputValue(Map sideInputKind, + Iterable elements) + throws Exception { + String className = getString(sideInputKind, PropertyNames.OBJECT_TYPE_NAME); + if (SINGLETON_KIND.equals(className)) { + Iterator iter = elements.iterator(); + if (iter.hasNext()) { + Object elem = iter.next(); + if (!iter.hasNext()) { + return elem; + } + } + throw new Exception( + "expecting a singleton side input to have a single value"); + + } else if (COLLECTION_KIND.equals(className)) { + return elements; + + } else { + throw new Exception("unexpected kind of side input: " + className); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + + static class SourceIterable implements Iterable { + final Source source; + + public SourceIterable(Source source) { + this.source = source; + } + + @Override + public Iterator iterator() { + try { + return new SourceIterator<>(source.iterator()); + } catch (Exception exn) { + throw new RuntimeException(exn); + } + } + } + + static class SourceIterator implements Iterator { + final Source.SourceIterator iterator; + + public SourceIterator(Source.SourceIterator iterator) { + this.iterator = iterator; + } + + @Override + public boolean hasNext() { + try { + return iterator.hasNext(); + } catch (Exception exn) { + throw new RuntimeException(exn); + } + } + + @Override + public T next() { + try { + return iterator.next(); + } catch (Exception exn) { + throw new RuntimeException(exn); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + static class ShardedIterable implements Iterable { + final List> shards; + + public ShardedIterable(List> shards) { + this.shards = shards; + } + + @Override + public Iterator iterator() { + return new ShardedIterator<>(shards.iterator()); + } + } + + static class ShardedIterator implements Iterator { + final Iterator> shards; + Iterator shard; + + public ShardedIterator(Iterator> shards) { + this.shards = shards; + this.shard = null; + } + + @Override + public boolean hasNext() { + boolean shardHasNext; + for (;;) { + shardHasNext = (shard != null && shard.hasNext()); + if (shardHasNext) { + break; + } + if (!shards.hasNext()) { + break; + } + shard = shards.next().iterator(); + } + return shardHasNext; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return shard.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SinkFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SinkFactory.java new file mode 100644 index 000000000000..df2d5ac75428 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SinkFactory.java @@ -0,0 +1,94 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.common.reflect.TypeToken; + +import java.util.HashMap; +import java.util.Map; + +/** + * Constructs a Sink from a Dataflow service protocol Sink definition. + * + * A SinkFactory concrete "subclass" should define a method with the following + * signature: + *
 {@code
+ * static SomeSinkSubclass create(PipelineOptions, CloudObject,
+ *                                   Coder, ExecutionContext);
+ * } 
+ */ +public final class SinkFactory { + // Do not instantiate. + private SinkFactory() {} + + /** + * A map from the short names of predefined sinks to their full + * factory class names. + */ + static Map predefinedSinkFactories = new HashMap<>(); + + static { + predefinedSinkFactories.put("TextSink", + TextSinkFactory.class.getName()); + predefinedSinkFactories.put("AvroSink", + AvroSinkFactory.class.getName()); + predefinedSinkFactories.put("ShuffleSink", + ShuffleSinkFactory.class.getName()); + } + + /** + * Creates a {@link Sink} from a Dataflow API Sink definition. + * + * @throws Exception if the sink could not be decoded and + * constructed + */ + public static Sink create( + PipelineOptions options, + com.google.api.services.dataflow.model.Sink cloudSink, + ExecutionContext executionContext) + throws Exception { + Coder coder = Serializer.deserialize(cloudSink.getCodec(), Coder.class); + CloudObject object = CloudObject.fromSpec(cloudSink.getSpec()); + + String className = predefinedSinkFactories.get(object.getClassName()); + if (className == null) { + className = object.getClassName(); + } + + try { + return InstanceBuilder.ofType(new TypeToken>() {}) + .fromClassName(className) + .fromFactoryMethod("create") + .withArg(PipelineOptions.class, options) + .withArg(CloudObject.class, object) + .withArg(Coder.class, coder) + .withArg(ExecutionContext.class, executionContext) + .build(); + + } catch (ClassNotFoundException exn) { + throw new Exception( + "unable to create a sink from " + cloudSink, exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceFactory.java new file mode 100644 index 000000000000..d4726094a3ea --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceFactory.java @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.CloudSourceUtils; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.common.reflect.TypeToken; + +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Constructs a Source from a Dataflow API Source definition. + * + * A SourceFactory concrete "subclass" should define a method with the following + * signature: + *
 {@code
+ * static SomeSourceSubclass create(PipelineOptions, CloudObject,
+ *                                     Coder, ExecutionContext);
+ * } 
+ */ +public final class SourceFactory { + // Do not instantiate. + private SourceFactory() {} + + /** + * A map from the short names of predefined sources to + * their full factory class names. + */ + static Map predefinedSourceFactories = new HashMap<>(); + + static { + predefinedSourceFactories.put( + "TextSource", + TextSourceFactory.class.getName()); + predefinedSourceFactories.put( + "AvroSource", + AvroSourceFactory.class.getName()); + predefinedSourceFactories.put( + "UngroupedShuffleSource", + UngroupedShuffleSourceFactory.class.getName()); + predefinedSourceFactories.put( + "PartitioningShuffleSource", + PartitioningShuffleSourceFactory.class.getName()); + predefinedSourceFactories.put( + "GroupingShuffleSource", + GroupingShuffleSourceFactory.class.getName()); + predefinedSourceFactories.put( + "InMemorySource", + InMemorySourceFactory.class.getName()); + predefinedSourceFactories.put( + "BigQuerySource", + BigQuerySourceFactory.class.getName()); + } + + /** + * Creates a Source from a Dataflow API Source definition. + * + * @throws Exception if the source could not be decoded and + * constructed + */ + public static Source create( + @Nullable PipelineOptions options, + com.google.api.services.dataflow.model.Source cloudSource, + @Nullable ExecutionContext executionContext) + throws Exception { + cloudSource = CloudSourceUtils.flattenBaseSpecs(cloudSource); + Coder coder = Serializer.deserialize(cloudSource.getCodec(), Coder.class); + CloudObject object = CloudObject.fromSpec(cloudSource.getSpec()); + + String sourceFactoryClassName = predefinedSourceFactories.get(object.getClassName()); + if (sourceFactoryClassName == null) { + sourceFactoryClassName = object.getClassName(); + } + + try { + return InstanceBuilder.ofType(new TypeToken>() {}) + .fromClassName(sourceFactoryClassName) + .fromFactoryMethod("create") + .withArg(PipelineOptions.class, options) + .withArg(CloudObject.class, object) + .withArg(Coder.class, coder) + .withArg(ExecutionContext.class, executionContext) + .build(); + + } catch (ClassNotFoundException exn) { + throw new Exception( + "unable to create a source from " + cloudSource, exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceOperationExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceOperationExecutor.java new file mode 100644 index 000000000000..2db18b272474 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceOperationExecutor.java @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudSourceOperationRequestToSourceOperationRequest; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceOperationResponseToCloudSourceOperationResponse; + +import com.google.api.services.dataflow.model.Source; +import com.google.api.services.dataflow.model.SourceOperationRequest; +import com.google.api.services.dataflow.model.SourceOperationResponse; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.MapTaskExecutor; +import com.google.cloud.dataflow.sdk.util.common.worker.WorkExecutor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An executor for a source operation, defined by a {@code SourceOperationRequest}. + */ +public class SourceOperationExecutor extends WorkExecutor { + private static final Logger LOG = LoggerFactory.getLogger(MapTaskExecutor.class); + + private final SourceOperationRequest request; + private SourceOperationResponse response; + + public SourceOperationExecutor(SourceOperationRequest request, + CounterSet counters) { + super(counters); + this.request = request; + } + + @Override + public void execute() throws Exception { + LOG.debug("Executing source operation"); + + Source sourceSpec; + if (request.getGetMetadata() != null) { + sourceSpec = request.getGetMetadata().getSource(); + } else if (request.getSplit() != null) { + sourceSpec = request.getSplit().getSource(); + } else { + throw new UnsupportedOperationException("Unknown source operation"); + } + + this.response = + sourceOperationResponseToCloudSourceOperationResponse( + CustomSourceFormatFactory.create(sourceSpec) + .performSourceOperation( + cloudSourceOperationRequestToSourceOperationRequest(request))); + + LOG.debug("Source operation execution complete"); + } + + public SourceOperationResponse getResponse() { + return response; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceOperationExecutorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceOperationExecutorFactory.java new file mode 100644 index 000000000000..10c862e46487 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceOperationExecutorFactory.java @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.api.services.dataflow.model.SourceOperationRequest; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * Creates a SourceOperationExecutor from a SourceOperation. + */ +public class SourceOperationExecutorFactory { + public static SourceOperationExecutor create(SourceOperationRequest request) + throws Exception { + CounterSet counters = new CounterSet(); + return new SourceOperationExecutor(request, counters); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceTranslationUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceTranslationUtils.java new file mode 100644 index 000000000000..1e0c8aa23491 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/SourceTranslationUtils.java @@ -0,0 +1,189 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.getDictionary; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.api.services.dataflow.model.Position; +import com.google.api.services.dataflow.model.SourceMetadata; +import com.google.api.services.dataflow.model.SourceOperationRequest; +import com.google.api.services.dataflow.model.SourceOperationResponse; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.worker.CustomSourceFormat; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Utilities for representing Source-specific objects + * using Dataflow model protos. + */ +public class SourceTranslationUtils { + public static Source.Progress cloudProgressToSourceProgress( + @Nullable ApproximateProgress cloudProgress) { + return cloudProgress == null ? null + : new DataflowSourceProgress(cloudProgress); + } + + public static Source.Position cloudPositionToSourcePosition( + @Nullable Position cloudPosition) { + return cloudPosition == null ? null + : new DataflowSourcePosition(cloudPosition); + } + + public static CustomSourceFormat.SourceOperationRequest + cloudSourceOperationRequestToSourceOperationRequest( + @Nullable SourceOperationRequest request) { + return request == null ? null + : new DataflowSourceOperationRequest(request); + } + + public static CustomSourceFormat.SourceOperationResponse + cloudSourceOperationResponseToSourceOperationResponse( + @Nullable SourceOperationResponse response) { + return response == null ? null + : new DataflowSourceOperationResponse(response); + } + + public static CustomSourceFormat.SourceSpec cloudSourceToSourceSpec( + @Nullable com.google.api.services.dataflow.model.Source cloudSource) { + return cloudSource == null ? null + : new DataflowSourceSpec(cloudSource); + } + + public static ApproximateProgress sourceProgressToCloudProgress( + @Nullable Source.Progress sourceProgress) { + return sourceProgress == null ? null + : ((DataflowSourceProgress) sourceProgress).cloudProgress; + } + + public static Position sourcePositionToCloudPosition( + @Nullable Source.Position sourcePosition) { + return sourcePosition == null ? null + : ((DataflowSourcePosition) sourcePosition).cloudPosition; + } + + public static SourceOperationRequest + sourceOperationRequestToCloudSourceOperationRequest( + @Nullable CustomSourceFormat.SourceOperationRequest request) { + return (request == null) ? null + : ((DataflowSourceOperationRequest) request).cloudRequest; + } + + public static SourceOperationResponse + sourceOperationResponseToCloudSourceOperationResponse( + @Nullable CustomSourceFormat.SourceOperationResponse response) { + return (response == null) ? null + : ((DataflowSourceOperationResponse) response).cloudResponse; + } + + public static com.google.api.services.dataflow.model.Source sourceSpecToCloudSource( + @Nullable CustomSourceFormat.SourceSpec spec) { + return (spec == null) ? null + : ((DataflowSourceSpec) spec).cloudSource; + } + + static class DataflowSourceProgress implements Source.Progress { + public final ApproximateProgress cloudProgress; + public DataflowSourceProgress(ApproximateProgress cloudProgress) { + this.cloudProgress = cloudProgress; + } + } + + static class DataflowSourcePosition implements Source.Position { + public final Position cloudPosition; + public DataflowSourcePosition(Position cloudPosition) { + this.cloudPosition = cloudPosition; + } + } + + static class DataflowSourceOperationRequest implements CustomSourceFormat.SourceOperationRequest { + public final SourceOperationRequest cloudRequest; + public DataflowSourceOperationRequest(SourceOperationRequest cloudRequest) { + this.cloudRequest = cloudRequest; + } + } + + static class DataflowSourceOperationResponse + implements CustomSourceFormat.SourceOperationResponse { + public final SourceOperationResponse cloudResponse; + public DataflowSourceOperationResponse(SourceOperationResponse cloudResponse) { + this.cloudResponse = cloudResponse; + } + } + + static class DataflowSourceSpec implements CustomSourceFormat.SourceSpec { + public final com.google.api.services.dataflow.model.Source cloudSource; + public DataflowSourceSpec(com.google.api.services.dataflow.model.Source cloudSource) { + this.cloudSource = cloudSource; + } + } + + // Represents a cloud Source as a dictionary for encoding inside the CUSTOM_SOURCE + // property of CloudWorkflowStep.input. + public static Map cloudSourceToDictionary( + com.google.api.services.dataflow.model.Source source) { + // Do not translate encoding - the source's encoding is translated elsewhere + // to the step's output info. + Map res = new HashMap<>(); + addDictionary(res, PropertyNames.CUSTOM_SOURCE_SPEC, source.getSpec()); + if (source.getMetadata() != null) { + addDictionary(res, PropertyNames.CUSTOM_SOURCE_METADATA, + cloudSourceMetadataToDictionary(source.getMetadata())); + } + if (source.getDoesNotNeedSplitting() != null) { + addBoolean(res, PropertyNames.CUSTOM_SOURCE_DOES_NOT_NEED_SPLITTING, + source.getDoesNotNeedSplitting()); + } + return res; + } + + private static Map cloudSourceMetadataToDictionary( + SourceMetadata metadata) { + Map res = new HashMap<>(); + if (metadata.getProducesSortedKeys() != null) { + addBoolean(res, PropertyNames.CUSTOM_SOURCE_PRODUCES_SORTED_KEYS, + metadata.getProducesSortedKeys()); + } + if (metadata.getEstimatedSizeBytes() != null) { + addLong(res, PropertyNames.CUSTOM_SOURCE_ESTIMATED_SIZE_BYTES, + metadata.getEstimatedSizeBytes()); + } + if (metadata.getInfinite() != null) { + addBoolean(res, PropertyNames.CUSTOM_SOURCE_IS_INFINITE, + metadata.getInfinite()); + } + return res; + } + + public static com.google.api.services.dataflow.model.Source dictionaryToCloudSource( + Map params) throws Exception { + com.google.api.services.dataflow.model.Source res = + new com.google.api.services.dataflow.model.Source(); + res.setSpec(getDictionary(params, PropertyNames.CUSTOM_SOURCE_SPEC)); + // CUSTOM_SOURCE_METADATA and CUSTOM_SOURCE_DOES_NOT_NEED_SPLITTING do not have to be + // translated, because they only make sense in cloud Source objects produced by the user. + return res; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSink.java new file mode 100644 index 000000000000..5fef80f72513 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSink.java @@ -0,0 +1,285 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.util.ShardingWritableByteChannel; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Random; + +import javax.annotation.Nullable; + +/** + * A sink that writes text files. + * + * @param the type of the elements written to the sink + */ +public class TextSink extends Sink { + + static final byte[] NEWLINE = getNewline(); + + private static byte[] getNewline() { + String newline = "\n"; + try { + return newline.getBytes("UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException("UTF-8 not supported", e); + } + } + + final String namePrefix; + final String shardFormat; + final String nameSuffix; + final int shardCount; + final boolean appendTrailingNewlines; + final String header; + final String footer; + final Coder coder; + + /** + * For testing only. + * + *

Used by simple tests which write to a single unsharded file. + */ + public static TextSink> createForTest( + String filename, + boolean appendTrailingNewlines, + @Nullable String header, + @Nullable String footer, + Coder coder) { + return create(filename, + "", + "", + 1, + appendTrailingNewlines, + header, + footer, + WindowedValue.getValueOnlyCoder(coder)); + } + + /** + * For DirectPipelineRunner only. + * It wraps the coder with {@code WindowedValue.ValueOnlyCoder}. + */ + public static TextSink> createForDirectPipelineRunner( + String filenamePrefix, + String shardFormat, + String filenameSuffix, + int shardCount, + boolean appendTrailingNewlines, + @Nullable String header, + @Nullable String footer, + Coder coder) { + return create(filenamePrefix, + shardFormat, + filenameSuffix, + shardCount, + appendTrailingNewlines, + header, + footer, + WindowedValue.getValueOnlyCoder(coder)); + } + + /** + * Constructs a new TextSink. + * + * @param filenamePrefix the prefix of output filenames. + * @param shardFormat the shard name template to use for output filenames. + * @param filenameSuffix the suffix of output filenames. + * @param shardCount the number of outupt shards to produce. + * @param appendTrailingNewlines true to append newlines to each output line. + * @param header text to place at the beginning of each output file. + * @param footer text to place at the end of each output file. + * @param coder the code used to encode elements for output. + */ + public static TextSink create(String filenamePrefix, + String shardFormat, + String filenameSuffix, + int shardCount, + boolean appendTrailingNewlines, + @Nullable String header, + @Nullable String footer, + Coder coder) { + return new TextSink<>(filenamePrefix, + shardFormat, + filenameSuffix, + shardCount, + appendTrailingNewlines, + header, + footer, + coder); + } + + private TextSink(String filenamePrefix, + String shardFormat, + String filenameSuffix, + int shardCount, + boolean appendTrailingNewlines, + @Nullable String header, + @Nullable String footer, + Coder coder) { + this.namePrefix = filenamePrefix; + this.shardFormat = shardFormat; + this.nameSuffix = filenameSuffix; + this.shardCount = shardCount; + this.appendTrailingNewlines = appendTrailingNewlines; + this.header = header; + this.footer = footer; + this.coder = coder; + } + + @Override + public SinkWriter writer() throws IOException { + String mimeType; + + if (!(coder instanceof WindowedValueCoder)) { + throw new IOException( + "Expected WindowedValueCoder for inputCoder, got: " + + coder.getClass().getName()); + } + Coder valueCoder = ((WindowedValueCoder) coder).getValueCoder(); + if (valueCoder.equals(StringUtf8Coder.of())) { + mimeType = MimeTypes.TEXT; + } else { + mimeType = MimeTypes.BINARY; + } + + WritableByteChannel writer = IOChannelUtils.create(namePrefix, shardFormat, + nameSuffix, shardCount, mimeType); + + if (writer instanceof ShardingWritableByteChannel) { + return new ShardingTextFileWriter((ShardingWritableByteChannel) writer); + } else { + return new TextFileWriter(writer); + } + } + + /** + * Abstract SinkWriter base class shared by sharded and unsharded Text + * writer implementations. + */ + abstract class AbstractTextFileWriter implements SinkWriter { + protected void init() throws IOException { + if (header != null) { + printLine(ShardingWritableByteChannel.ALL_SHARDS, + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), header)); + } + } + + /** + * Adds a value to the sink. Returns the size in bytes of the data written. + * The return value does -not- include header/footer size. + */ + @Override + public long add(T value) throws IOException { + return printLine(getShardNum(value), + CoderUtils.encodeToByteArray(coder, value)); + } + + @Override + public void close() throws IOException { + if (footer != null) { + printLine(ShardingWritableByteChannel.ALL_SHARDS, + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), footer)); + } + } + + protected long printLine(int shardNum, byte[] line) throws IOException { + long length = line.length; + write(shardNum, ByteBuffer.wrap(line)); + + if (appendTrailingNewlines) { + write(shardNum, ByteBuffer.wrap(NEWLINE)); + length += NEWLINE.length; + } + + return length; + } + + protected abstract void write(int shardNum, ByteBuffer buf) + throws IOException; + protected abstract int getShardNum(T value); + } + + /** An unsharded SinkWriter for a TextSink. */ + class TextFileWriter extends AbstractTextFileWriter { + private final WritableByteChannel outputChannel; + + TextFileWriter(WritableByteChannel outputChannel) throws IOException { + this.outputChannel = outputChannel; + init(); + } + + @Override + public void close() throws IOException { + super.close(); + outputChannel.close(); + } + + @Override + protected void write(int shardNum, ByteBuffer buf) throws IOException { + outputChannel.write(buf); + } + + @Override + protected int getShardNum(T value) { + return 0; + } + } + + /** A sharding SinkWriter for a TextSink. */ + class ShardingTextFileWriter extends AbstractTextFileWriter { + private final Random rng = new Random(); + private final int numShards; + private final ShardingWritableByteChannel outputChannel; + + // TODO: add support for user-defined sharding function. + ShardingTextFileWriter(ShardingWritableByteChannel outputChannel) + throws IOException { + this.outputChannel = outputChannel; + numShards = outputChannel.getNumShards(); + init(); + } + + @Override + public void close() throws IOException { + super.close(); + outputChannel.close(); + } + + @Override + protected void write(int shardNum, ByteBuffer buf) throws IOException { + outputChannel.writeToShard(shardNum, buf); + } + + @Override + protected int getShardNum(T value) { + return rng.nextInt(numShards); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkFactory.java new file mode 100644 index 000000000000..bac663dea2da --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +/** + * Creates a TextSink from a CloudObject spec. + */ +public final class TextSinkFactory { + // Do not instantiate. + private TextSinkFactory() {} + + public static TextSink create(PipelineOptions options, + CloudObject spec, + Coder coder, + ExecutionContext executionContext) + throws Exception { + return create(spec, coder); + } + + static TextSink create(CloudObject spec, Coder coder) + throws Exception { + return TextSink.create( + getString(spec, PropertyNames.FILENAME), + "", // No shard template + "", // No suffix + 1, // Exactly one output file + getBoolean(spec, PropertyNames.APPEND_TRAILING_NEWLINES, true), + getString(spec, PropertyNames.HEADER, null), + getString(spec, PropertyNames.FOOTER, null), + coder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSource.java new file mode 100644 index 000000000000..5bbcba0e6b91 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSource.java @@ -0,0 +1,383 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.IOChannelFactory; +import com.google.cloud.dataflow.sdk.util.common.worker.ProgressTracker; +import com.google.cloud.dataflow.sdk.util.common.worker.ProgressTrackerGroup; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PushbackInputStream; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.util.Collection; +import java.util.Iterator; + +import javax.annotation.Nullable; + +/** + * A source that reads text files. + * + * @param the type of the elements read from the source + */ +public class TextSource extends FileBasedSource { + final boolean stripTrailingNewlines; + + public TextSource(String filename, + boolean stripTrailingNewlines, + @Nullable Long startPosition, + @Nullable Long endPosition, + Coder coder) { + this(filename, stripTrailingNewlines, + startPosition, endPosition, coder, true); + } + + protected TextSource(String filename, + boolean stripTrailingNewlines, + @Nullable Long startPosition, + @Nullable Long endPosition, + Coder coder, + boolean useDefaultBufferSize) { + super(filename, startPosition, endPosition, coder, useDefaultBufferSize); + this.stripTrailingNewlines = stripTrailingNewlines; + } + + @Override + protected SourceIterator newSourceIteratorForRangeInFile( + IOChannelFactory factory, String oneFile, long startPosition, + @Nullable Long endPosition) + throws IOException { + // Position before the first record, so we can find the record beginning. + final long start = startPosition > 0 ? startPosition - 1 : 0; + + TextFileIterator iterator = newSourceIteratorForRangeWithStrictStart( + factory, oneFile, stripTrailingNewlines, start, endPosition); + + // Skip the initial record if start position was set. + if (startPosition > 0 && iterator.hasNext()) { + iterator.advance(); + } + + return iterator; + } + + @Override + protected SourceIterator newSourceIteratorForFiles( + IOChannelFactory factory, Collection files) throws IOException { + if (files.size() == 1) { + return newSourceIteratorForFile( + factory, files.iterator().next(), stripTrailingNewlines); + } + + return new TextFileMultiIterator( + factory, files.iterator(), stripTrailingNewlines); + } + + private TextFileIterator newSourceIteratorForFile( + IOChannelFactory factory, String input, boolean stripTrailingNewlines) + throws IOException { + return newSourceIteratorForRangeWithStrictStart( + factory, input, stripTrailingNewlines, 0, null); + } + + /** + * Returns a new iterator for lines in the given range in the given + * file. Does NOT skip the first line if the range starts in the + * middle of a line (instead, the latter half that starts at + * startOffset will be returned as the first element). + */ + private TextFileIterator newSourceIteratorForRangeWithStrictStart( + IOChannelFactory factory, String input, boolean stripTrailingNewlines, + long startOffset, @Nullable Long endOffset) throws IOException { + ReadableByteChannel reader = factory.open(input); + if (!(reader instanceof SeekableByteChannel)) { + throw new UnsupportedOperationException( + "Unable to seek in stream for " + input); + } + + SeekableByteChannel seeker = (SeekableByteChannel) reader; + + return new TextFileIterator( + new CopyableSeekableByteChannel(seeker), + stripTrailingNewlines, startOffset, endOffset); + } + + class TextFileMultiIterator extends LazyMultiSourceIterator { + private final IOChannelFactory factory; + private final boolean stripTrailingNewlines; + + public TextFileMultiIterator(IOChannelFactory factory, + Iterator inputs, boolean stripTrailingNewlines) { + super(inputs); + this.factory = factory; + this.stripTrailingNewlines = stripTrailingNewlines; + } + + @Override + protected SourceIterator open(String input) throws IOException { + return newSourceIteratorForFile(factory, input, stripTrailingNewlines); + } + } + + class TextFileIterator extends FileBasedIterator { + private final boolean stripTrailingNewlines; + private ScanState state; + + TextFileIterator(CopyableSeekableByteChannel seeker, + boolean stripTrailingNewlines, + long startOffset, + @Nullable Long endOffset) throws IOException { + this(seeker, stripTrailingNewlines, startOffset, startOffset, endOffset, + new ProgressTrackerGroup() { + @Override + protected void report(Integer lineLength) { + notifyElementRead(lineLength.longValue()); + } + }.start(), new ScanState(BUF_SIZE, !stripTrailingNewlines)); + } + + private TextFileIterator(CopyableSeekableByteChannel seeker, + boolean stripTrailingNewlines, + long startOffset, + long offset, + @Nullable Long endOffset, + ProgressTracker tracker, + ScanState state) throws IOException { + super(seeker, startOffset, offset, endOffset, tracker); + + this.stripTrailingNewlines = stripTrailingNewlines; + this.state = state; + } + + private TextFileIterator(TextFileIterator it) throws IOException { + this(it.seeker.copy(), it.stripTrailingNewlines, + /* Correctly adjust the start position of the seeker given + * that it may hold bytes that have been read and now reside + * in the read buffer (that is copied during cloning) */ + it.startOffset + it.state.totalBytesRead, + it.offset, + it.endOffset, it.tracker.copy(), it.state.copy()); + } + + @Override + public SourceIterator copy() throws IOException { + return new TextFileIterator(this); + } + + /** + * Reads a line of text. A line is considered to be terminated by any + * one of a line feed ({@code '\n'}), a carriage return + * ({@code '\r'}), or a carriage return followed immediately by a linefeed + * ({@code "\r\n"}). + * + * @return a {@code ByteArrayOutputStream} containing the contents of the + * line, with any line-termination characters stripped if + * keepNewlines==false, or {@code null} if the end of the stream has + * been reached. + * @throws IOException if an I/O error occurs + */ + @Override + protected ByteArrayOutputStream readElement() + throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(BUF_SIZE); + + int charsConsumed = 0; + while (true) { + // Attempt to read blocks of data at a time + // until a separator is found. + if (!state.readBytes(stream)) { + break; + } + + int consumed = state.consumeUntilSeparator(buffer); + charsConsumed += consumed; + if (consumed > 0 && state.separatorFound()) { + if (state.lastByteRead() == '\r') { + charsConsumed += state.copyCharIfLinefeed(buffer, stream); + } + break; + } + } + + if (charsConsumed == 0) { + // Note that charsConsumed includes the size of any separators that may + // have been stripped off -- so if we didn't get anything, we're at the + // end of the file. + return null; + } + + offset += charsConsumed; + tracker.saw(charsConsumed); + return buffer; + } + } + + /** + * ScanState encapsulates the state for the current buffer of text + * being scanned. + */ + private static class ScanState { + private int start; // Valid bytes in buf start at this index + private int pos; // Where the separator is in the buf (if one was found) + private int end; // the index of the end of bytes in buf + private byte[] buf; + private boolean keepNewlines; + private byte lastByteRead; + private long totalBytesRead; + + public ScanState(int size, boolean keepNewlines) { + this.start = 0; + this.pos = 0; + this.end = 0; + this.buf = new byte[size]; + this.keepNewlines = keepNewlines; + totalBytesRead = 0; + } + + public ScanState copy() { + byte[] bufCopy = new byte[buf.length]; // copy :( + System.arraycopy(buf, start, bufCopy, start, end - start); + return new ScanState( + this.keepNewlines, this.start, this.pos, this.end, + bufCopy, this.lastByteRead, 0); + } + + private ScanState( + boolean keepNewlines, int start, int pos, int end, + byte[] buf, byte lastByteRead, long totalBytesRead) { + this.start = start; + this.pos = pos; + this.end = end; + this.buf = buf; + this.keepNewlines = keepNewlines; + this.lastByteRead = lastByteRead; + this.totalBytesRead = totalBytesRead; + } + + public boolean readBytes(PushbackInputStream stream) throws IOException { + if (start < end) { + return true; + } + assert end <= buf.length : end + " > " + buf.length; + int bytesRead = stream.read(buf, end, buf.length - end); + if (bytesRead == -1) { + return false; + } + totalBytesRead += bytesRead; + end += bytesRead; + return true; + } + + /** + * Consumes characters until a separator character is found or the + * end of buffer is reached. + * + * Updates the state to indicate the position of the separator + * character. If pos==len, no separator was found. + * + * @return the number of characters consumed. + */ + public int consumeUntilSeparator(ByteArrayOutputStream out) { + for (pos = start; pos < end; ++pos) { + lastByteRead = buf[pos]; + if (separatorFound()) { + int charsConsumed = (pos - start + 1); // The separator is consumed + copyToOutputBuffer(out); + start = pos + 1; // skip the separator + return charsConsumed; + } + } + // No separator found + assert pos == end; + int charsConsumed = (pos - start); + out.write(buf, start, charsConsumed); + start = 0; + end = 0; + pos = 0; + return charsConsumed; + } + + public boolean separatorFound() { + return lastByteRead == '\n' || lastByteRead == '\r'; + } + + public byte lastByteRead() { + return buf[pos]; + } + + public int bytesBuffered() { + assert end >= start : end + " must be >= " + start; + return end - start; + } + + /** + * Copies data from the input buffer to the output buffer. + * + * If keepNewlines==true, line-termination characters are included in the copy. + */ + private void copyToOutputBuffer(ByteArrayOutputStream out) { + int charsCopied = pos - start; + if (keepNewlines && separatorFound()) { + charsCopied++; + } + out.write(buf, start, charsCopied); + } + + /** + * Scans the input buffer to determine if a matched carriage return + * has an accompanying linefeed and process the input buffer accordingly. + * + * If keepNewlines==true and a linefeed character is detected, + * it is included in the copy. + * + * @return the number of characters consumed + */ + private int copyCharIfLinefeed(ByteArrayOutputStream out, PushbackInputStream stream) + throws IOException { + int charsConsumed = 0; + // Check to make sure we don't go off the end of the buffer + if ((pos + 1) < end) { + if (buf[pos + 1] == '\n') { + charsConsumed++; + pos++; + start++; + if (keepNewlines) { + out.write('\n'); + } + } + } else { + // We are at the end of the buffer and need one more + // byte. Get it the slow but safe way. + int b = stream.read(); + if (b == '\n') { + charsConsumed++; + totalBytesRead++; + if (keepNewlines) { + out.write(b); + } + } else if (b != -1) { + // Consider replacing unread() since it may be slow if + // iterators are cloned frequently. + stream.unread(b); + } + } + return charsConsumed; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceFactory.java new file mode 100644 index 000000000000..a15c2d505c47 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceFactory.java @@ -0,0 +1,74 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.getBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.getLong; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.api.services.dataflow.model.Source; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Serializer; + +/** + * Creates a TextSource from a CloudObject spec. + */ +public class TextSourceFactory { + // Do not instantiate. + private TextSourceFactory() {} + + public static TextSource create(PipelineOptions options, + CloudObject spec, + Coder coder, + ExecutionContext executionContext) + throws Exception { + return create(spec, coder); + } + + static TextSource create(CloudObject spec, + Coder coder) + throws Exception { + return create(spec, coder, true); + } + + public static TextSource create(Source spec) + throws Exception { + return create( + CloudObject.fromSpec(spec.getSpec()), + Serializer.deserialize(spec.getCodec(), Coder.class)); + } + + static TextSource create(CloudObject spec, + Coder coder, + boolean useDefaultBufferSize) throws Exception { + String filenameOrPattern = getString(spec, PropertyNames.FILENAME, null); + if (filenameOrPattern == null) { + filenameOrPattern = getString(spec, PropertyNames.FILEPATTERN, null); + } + return new TextSource<>( + filenameOrPattern, + getBoolean(spec, PropertyNames.STRIP_TRAILING_NEWLINES, true), + getLong(spec, PropertyNames.START_OFFSET, null), + getLong(spec, PropertyNames.END_OFFSET, null), + coder, + useDefaultBufferSize); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSource.java new file mode 100644 index 000000000000..d7d0cf7cf841 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSource.java @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.BatchingShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import java.io.IOException; +import java.util.Iterator; + +import javax.annotation.Nullable; + +/** + * A source that reads from a shuffled dataset, without any key grouping. + * Returns just the values. (This reader is for an UNGROUPED shuffle session.) + * + * @param the type of the elements read from the source + */ +public class UngroupedShuffleSource extends Source { + final byte[] shuffleReaderConfig; + final String startShufflePosition; + final String stopShufflePosition; + final Coder coder; + + public UngroupedShuffleSource(PipelineOptions options, + byte[] shuffleReaderConfig, + @Nullable String startShufflePosition, + @Nullable String stopShufflePosition, + Coder coder) { + this.shuffleReaderConfig = shuffleReaderConfig; + this.startShufflePosition = startShufflePosition; + this.stopShufflePosition = stopShufflePosition; + this.coder = coder; + } + + @Override + public SourceIterator iterator() throws IOException { + Preconditions.checkArgument(shuffleReaderConfig != null); + return iterator(new BatchingShuffleEntryReader( + new ChunkingShuffleBatchReader(new ApplianceShuffleReader( + shuffleReaderConfig)))); + } + + SourceIterator iterator(ShuffleEntryReader reader) throws IOException { + return new UngroupedShuffleSourceIterator(reader); + } + + /** + * A SourceIterator that reads from a ShuffleEntryReader and extracts + * just the values. + */ + class UngroupedShuffleSourceIterator extends AbstractSourceIterator { + Iterator iterator; + + UngroupedShuffleSourceIterator(ShuffleEntryReader reader) + throws IOException { + this.iterator = reader.read( + ByteArrayShufflePosition.fromBase64(startShufflePosition), + ByteArrayShufflePosition.fromBase64(stopShufflePosition)); + } + + @Override + public boolean hasNext() throws IOException { + return iterator.hasNext(); + } + + @Override + public T next() throws IOException { + ShuffleEntry record = iterator.next(); + // Throw away the primary and the secondary keys. + byte[] value = record.getValue(); + notifyElementRead(record.length()); + return CoderUtils.decodeFromByteArray(coder, value); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSourceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSourceFactory.java new file mode 100644 index 000000000000..adff71226d6b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSourceFactory.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.decodeBase64; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +/** + * Creates an UngroupedShuffleSource from a CloudObject spec. + */ +public class UngroupedShuffleSourceFactory { + // Do not instantiate. + private UngroupedShuffleSourceFactory() {} + + public static UngroupedShuffleSource create( + PipelineOptions options, + CloudObject spec, + Coder coder, + ExecutionContext executionContext) + throws Exception { + return create(options, spec, coder); + } + + static UngroupedShuffleSource create( + PipelineOptions options, + CloudObject spec, + Coder coder) + throws Exception { + return new UngroupedShuffleSource<>( + options, + decodeBase64(getString(spec, PropertyNames.SHUFFLE_READER_CONFIG)), + getString(spec, PropertyNames.START_SHUFFLE_POSITION, null), + getString(spec, PropertyNames.END_SHUFFLE_POSITION, null), + coder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingFormatter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingFormatter.java new file mode 100644 index 000000000000..85805773c706 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingFormatter.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker.logging; + +import static com.google.cloud.dataflow.sdk.runners.worker.logging.DataflowWorkerLoggingInitializer.LEVELS; + +import com.google.common.base.MoreObjects; + +import org.joda.time.format.DateTimeFormatter; +import org.joda.time.format.ISODateTimeFormat; +import org.slf4j.MDC; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.logging.Formatter; +import java.util.logging.LogRecord; + +/** + * Formats {@link LogRecord} into the following format: + * ISO8601Date LogLevel JobId WorkerId WorkId ThreadId LoggerName LogMessage + * with one or more additional lines for any {@link Throwable} associated with + * the {@link LogRecord}. The exception is output using + * {@link Throwable#printStackTrace()}. + */ +public class DataflowWorkerLoggingFormatter extends Formatter { + private static final DateTimeFormatter DATE_FORMATTER = + ISODateTimeFormat.dateTime().withZoneUTC(); + public static final String MDC_DATAFLOW_JOB_ID = "dataflow.jobId"; + public static final String MDC_DATAFLOW_WORKER_ID = "dataflow.workerId"; + public static final String MDC_DATAFLOW_WORK_ID = "dataflow.workId"; + + @Override + public String format(LogRecord record) { + String exception = formatException(record.getThrown()); + return DATE_FORMATTER.print(record.getMillis()) + + " " + MoreObjects.firstNonNull(LEVELS.get(record.getLevel()), + record.getLevel().getName()) + + " " + MoreObjects.firstNonNull(MDC.get(MDC_DATAFLOW_JOB_ID), "unknown") + + " " + MoreObjects.firstNonNull(MDC.get(MDC_DATAFLOW_WORKER_ID), "unknown") + + " " + MoreObjects.firstNonNull(MDC.get(MDC_DATAFLOW_WORK_ID), "unknown") + + " " + record.getThreadID() + + " " + record.getLoggerName() + + " " + record.getMessage() + "\n" + + (exception != null ? exception : ""); + } + + /** + * Formats the throwable as per {@link Throwable#printStackTrace()}. + * + * @param thrown The throwable to format. + * @return A string containing the contents of {@link Throwable#printStackTrace()}. + */ + private String formatException(Throwable thrown) { + if (thrown == null) { + return null; + } + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + thrown.printStackTrace(pw); + pw.close(); + return sw.toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingInitializer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingInitializer.java new file mode 100644 index 000000000000..80ccf7084bcb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingInitializer.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker.logging; + +import com.google.common.collect.ImmutableBiMap; + +import java.io.IOException; +import java.util.logging.ConsoleHandler; +import java.util.logging.FileHandler; +import java.util.logging.Formatter; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogManager; +import java.util.logging.Logger; + +/** + * Sets up java.util.Logging configuration on the Dataflow Worker Harness with a + * console and file logger. The console and file loggers use the + * {@link DataflowWorkerLoggingFormatter} format. A user can override + * the logging level and location by specifying the Java system properties + * "dataflow.worker.logging.level" and "dataflow.worker.logging.location" respectively. + * The default log level is INFO and the default location is a file named dataflow-worker.log + * within the systems temporary directory. + */ +public class DataflowWorkerLoggingInitializer { + private static final String DEFAULT_LOGGING_LOCATION = "/tmp/dataflow-worker.log"; + private static final String ROOT_LOGGER_NAME = ""; + public static final String DATAFLOW_WORKER_LOGGING_LEVEL = "dataflow.worker.logging.level"; + public static final String DATAFLOW_WORKER_LOGGING_LOCATION = "dataflow.worker.logging.location"; + public static final ImmutableBiMap LEVELS = + ImmutableBiMap.builder() + .put(Level.SEVERE, "ERROR") + .put(Level.WARNING, "WARNING") + .put(Level.INFO, "INFO") + .put(Level.FINE, "DEBUG") + .put(Level.FINEST, "TRACE") + .build(); + private static final String DEFAULT_LOG_LEVEL = LEVELS.get(Level.INFO); + + public void initialize() { + initialize(LogManager.getLogManager()); + } + + void initialize(LogManager logManager) { + try { + Level logLevel = LEVELS.inverse().get( + System.getProperty(DATAFLOW_WORKER_LOGGING_LEVEL, DEFAULT_LOG_LEVEL)); + Formatter formatter = new DataflowWorkerLoggingFormatter(); + + FileHandler fileHandler = new FileHandler( + System.getProperty(DATAFLOW_WORKER_LOGGING_LOCATION, DEFAULT_LOGGING_LOCATION), + true /* Append so that we don't squash existing logs */); + fileHandler.setFormatter(formatter); + fileHandler.setLevel(logLevel); + + ConsoleHandler consoleHandler = new ConsoleHandler(); + consoleHandler.setFormatter(formatter); + consoleHandler.setLevel(logLevel); + + // Reset the global log manager, get the root logger and remove the default log handlers. + logManager.reset(); + Logger rootLogger = logManager.getLogger(ROOT_LOGGER_NAME); + for (Handler handler : rootLogger.getHandlers()) { + rootLogger.removeHandler(handler); + } + + rootLogger.setLevel(logLevel); + rootLogger.addHandler(consoleHandler); + rootLogger.addHandler(fileHandler); + } catch (SecurityException | IOException e) { + throw new ExceptionInInitializerError(e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/package-info.java new file mode 100644 index 000000000000..615ed6474392 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Implementation of the harness that runs on each Google Compute Engine instance to coordinate + * execution of Pipeline code. + */ +@ParametersAreNonnullByDefault +package com.google.cloud.dataflow.sdk.runners.worker; + +import javax.annotation.ParametersAreNonnullByDefault; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/DataflowAssert.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/DataflowAssert.java new file mode 100644 index 000000000000..d4fe32ffd86f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/DataflowAssert.java @@ -0,0 +1,374 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Collection; + +/** + * An assertion on the contents of a {@link PCollection} + * incorporated into the pipeline. Such an assertion + * can be checked no matter what kind of + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} is + * used, so it's good for testing using the + * {@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner}, + * the + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner}, + * etc. + * + *

Note that the {@code DataflowAssert} call must precede the call + * to {@link com.google.cloud.dataflow.sdk.Pipeline#run}. + * + *

Examples of use: + *

{@code
+ * Pipeline p = TestPipeline.create();
+ * ...
+ * PCollection output =
+ *      input
+ *      .apply(ParDo.of(new TestDoFn()));
+ * DataflowAssert.that(output)
+ *     .containsInAnyOrder("out1", "out2", "out3");
+ * ...
+ * PCollection ints = ...
+ * PCollection sum =
+ *     ints
+ *     .apply(Combine.globally(new SumInts()));
+ * DataflowAssert.that(sum)
+ *     .is(42);
+ * ...
+ * p.run();
+ * }
+ * + *

JUnit and Hamcrest must be linked in by any code that uses DataflowAssert. + * + * @param The type of elements in the input collection. + */ +public class DataflowAssert { + /** + * Constructs an IterableAssert for the elements of the provided + * {@code PCollection}. + */ + public static IterableAssert that(PCollection futureResult) { + return new IterableAssert<>(futureResult.apply(View.asIterable())); + } + + /** + * Constructs an IterableAssert for the value of the provided + * {@code PCollection>}, which must be a singleton. + */ + public static IterableAssert thatSingletonIterable( + PCollection> futureResult) { + return new IterableAssert<>(futureResult.apply(View.>asSingleton())); + } + + /** + * Constructs an IterableAssert for the value of the provided + * {@code PCollectionView, ?>}. + */ + public static IterableAssert thatIterable( + PCollectionView, ?> futureResult) { + return new IterableAssert<>(futureResult); + } + + /** + * An assertion about the contents of a {@link PCollectionView<, ?>} + */ + public static class IterableAssert implements Serializable { + private final PCollectionView, ?> actualResults; + + private IterableAssert(PCollectionView, ?> futureResult) { + actualResults = futureResult; + } + + /** + * Applies a SerializableFunction to check the elements of the Iterable. + * + *

Returns this IterableAssert. + */ + public IterableAssert satisfies( + final SerializableFunction, Void> checkerFn) { + + actualResults.getPipeline() + .apply(Create.of((Void) null)) + .setCoder(VoidCoder.of()) + .apply(ParDo + .withSideInputs(actualResults) + .of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + Iterable actualContents = c.sideInput(actualResults); + checkerFn.apply(actualContents); + } + })); + + return this; + } + + /** + * Checks that the Iterable contains the expected elements, in any + * order. + * + *

Returns this IterableAssert. + */ + public IterableAssert containsInAnyOrder(T... expectedElements) { + return this.satisfies(new AssertContainsInAnyOrder(expectedElements)); + } + + /** + * Checks that the Iterable contains the expected elements, in any + * order. + * + *

Returns this IterableAssert. + */ + public IterableAssert containsInAnyOrder( + Collection expectedElements) { + return this.satisfies(new AssertContainsInAnyOrder(expectedElements)); + } + + /** + * Checks that the Iterable contains the expected elements, in the + * specified order. + * + *

Returns this IterableAssert. + */ + public IterableAssert containsInOrder(T... expectedElements) { + return this.satisfies(new AssertContainsInOrder(expectedElements)); + } + + /** + * Checks that the Iterable contains the expected elements, in the + * specified order. + * + *

Returns this IterableAssert. + */ + public IterableAssert containsInOrder(Collection expectedElements) { + return this.satisfies(new AssertContainsInOrder(expectedElements)); + } + + /** + * SerializableFunction that performs an {@code Assert.assertThat()} + * operation using a {@code Matcher} operation that takes an array + * of elements. + */ + static class AssertThatIterable extends AssertThat, T[]> { + AssertThatIterable(T[] expected, + String matcherClassName, + String matcherFactoryMethodName) { + super(expected, Object[].class, + matcherClassName, matcherFactoryMethodName); + } + } + + /** + * SerializableFunction that verifies that an Iterable contains + * expected items in any order. + */ + static class AssertContainsInAnyOrder extends AssertThatIterable { + AssertContainsInAnyOrder(T... expected) { + super(expected, + "org.hamcrest.collection.IsIterableContainingInAnyOrder", + "containsInAnyOrder"); + } + @SuppressWarnings("unchecked") + AssertContainsInAnyOrder(Collection expected) { + this((T[]) expected.toArray()); + } + } + + /** + * SerializableFunction that verifies that an Iterable contains + * expected items in the provided order. + */ + static class AssertContainsInOrder extends AssertThatIterable { + AssertContainsInOrder(T... expected) { + super(expected, + "org.hamcrest.collection.IsIterableContainingInOrder", + "contains"); + } + @SuppressWarnings("unchecked") + AssertContainsInOrder(Collection expected) { + this((T[]) expected.toArray()); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Constructs a SingletonAssert for the value of the provided + * {@code PCollection}, which must be a singleton. + */ + public static SingletonAssert thatSingleton(PCollection futureResult) { + return new SingletonAssert<>(futureResult.apply(View.asSingleton())); + } + + /** + * An assertion about a single value. + */ + public static class SingletonAssert implements Serializable { + private final PCollectionView actualResult; + + private SingletonAssert(PCollectionView futureResult) { + actualResult = futureResult; + } + + /** + * Applies a SerializableFunction to check the value of this + * SingletonAssert's view. + * + *

Returns this SingletonAssert. + */ + public SingletonAssert satisfies(final SerializableFunction checkerFn) { + actualResult.getPipeline() + .apply(Create.of((Void) null)) + .setCoder(VoidCoder.of()) + .apply(ParDo + .withSideInputs(actualResult) + .of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + T actualContents = c.sideInput(actualResult); + checkerFn.apply(actualContents); + } + })); + + return this; + } + + /** + * Checks that the value of this SingletonAssert's view is equal + * to the expected value. + * + *

Returns this SingletonAssert. + */ + public SingletonAssert is(T expectedValue) { + return this.satisfies(new AssertIs(expectedValue)); + } + + /** + * SerializableFunction that performs an {@code Assert.assertThat()} + * operation using a {@code Matcher} operation that takes a single element. + */ + static class AssertThatValue extends AssertThat { + AssertThatValue(T expected, + String matcherClassName, + String matcherFactoryMethodName) { + super(expected, Object.class, + matcherClassName, matcherFactoryMethodName); + } + } + + /** + * SerializableFunction that verifies that a value is equal to an + * expected value. + */ + public static class AssertIs extends AssertThatValue { + AssertIs(T expected) { + super(expected, "org.hamcrest.core.IsEqual", "equalTo"); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + + // Do not instantiate. + private DataflowAssert() {} + + /** + * SerializableFunction that performs an {@code Assert.assertThat()} + * operation using a {@code Matcher} operation. + * + *

The MatcherFactory should take an {@code Expected} and + * produce a Matcher to be used to check an {@code Actual} value + * against. + */ + public static class AssertThat + implements SerializableFunction { + final Expected expected; + final Class expectedClass; + final String matcherClassName; + final String matcherFactoryMethodName; + + AssertThat(Expected expected, + Class expectedClass, + String matcherClassName, + String matcherFactoryMethodName) { + this.expected = expected; + this.expectedClass = expectedClass; + this.matcherClassName = matcherClassName; + this.matcherFactoryMethodName = matcherFactoryMethodName; + } + + @Override + public Void apply(Actual in) { + try { + Method matcherFactoryMethod = Class.forName(this.matcherClassName) + .getMethod(this.matcherFactoryMethodName, expectedClass); + Object matcher = matcherFactoryMethod.invoke(null, (Object) expected); + Method assertThatMethod = Class.forName("org.junit.Assert") + .getMethod("assertThat", + Object.class, + Class.forName("org.hamcrest.Matcher")); + assertThatMethod.invoke(null, in, matcher); + } catch (InvocationTargetException e) { + // An error in the assertThat or matcher itself. + throw new RuntimeException(e); + } catch (ReflectiveOperationException e) { + // An error looking up the classes and methods. + throw new RuntimeException( + "DataflowAssert requires that JUnit and Hamcrest be linked in.", + e); + } + return null; + } + } + + /** + * SerializableFunction that performs an {@code Assert.assertThat()} + * operation using a {@code Matcher} operation that takes a single element. + */ + static class AssertThatValue extends AssertThat { + AssertThatValue(T expected, + String matcherClassName, + String matcherFactoryMethodName) { + super(expected, Object.class, + matcherClassName, matcherFactoryMethodName); + } + } + + /** + * SerializableFunction that verifies that a value is equal to an + * expected value. + */ + public static class AssertIs extends AssertThatValue { + public AssertIs(T expected) { + super(expected, "org.hamcrest.core.IsEqual", "equalTo"); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/RunnableOnService.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/RunnableOnService.java new file mode 100644 index 000000000000..048ea36a2533 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/RunnableOnService.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +/** + * Category tag for tests that can be run on the DataflowPipelineRunner if the + * runIntegrationTestOnService System property is set to true. + * Example usage: + *


+ *     {@literal @}Test
+ *     {@literal @}Category(RunnableOnService.class)
+ *     public void testParDo() {...
+ * 
+ */ +public interface RunnableOnService {} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineOptions.java new file mode 100644 index 000000000000..e9f8f828120f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineOptions.java @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.options.BlockingDataflowPipelineOptions; + +/** + * A set of options used to configure the {@link TestPipeline}. + */ +public interface TestDataflowPipelineOptions extends BlockingDataflowPipelineOptions { + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunner.java new file mode 100644 index 000000000000..96da50189a90 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunner.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; + +/** + * TestDataflowPipelineRunner is a pipeline runner that wraps a + * DataflowPipelineRunner when running tests against the {@link TestPipeline}. + * + * @see TestPipeline + */ +public class TestDataflowPipelineRunner extends BlockingDataflowPipelineRunner { + TestDataflowPipelineRunner( + DataflowPipelineRunner internalRunner, + MonitoringUtil.JobMessagesHandler jobMessagesHandler) { + super(internalRunner, jobMessagesHandler); + } + + @Override + public PipelineJobState run(Pipeline pipeline) { + PipelineJobState state = super.run(pipeline); + if (state.getJobState() != MonitoringUtil.JobState.DONE) { + throw new AssertionError("The dataflow failed."); + } + return state; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestPipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestPipeline.java new file mode 100644 index 000000000000..6044365a664d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestPipeline.java @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.common.base.Optional; +import com.google.common.collect.Iterators; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Iterator; + +/** + * A creator of test pipelines which can be used inside of tests that can be + * configured to run locally or against the live service. + * + *

It is recommended to tag hand-selected tests for this purpose using the + * RunnableOnService Category annotation, as each test run against the service + * will spin up and tear down a single VM. + * + *

In order to run tests on the dataflow pipeline service, the following + * conditions must be met: + *

    + *
  • runIntegrationTestOnService System property must be set to true. + *
  • System property "projectName" must be set to your Cloud project. + *
  • System property "temp_gcs_directory" must be set to a valid GCS bucket. + *
  • Jars containing the SDK and test classes must be added to the test classpath. + *
+ * + *

Use {@link DataflowAssert} for tests, as it integrates with this test + * harness in both direct and remote execution modes. For example: + * + *

{@code
+ * Pipeline p = TestPipeline.create();
+ * PCollection output = ...
+ *
+ * DataflowAssert.that(output)
+ *     .containsInAnyOrder(1, 2, 3, 4);
+ * p.run();
+ * }
+ * + */ +public class TestPipeline extends Pipeline { + private static final String PROPERTY_DATAFLOW_OPTIONS = "dataflowOptions"; + private static final Logger LOG = LoggerFactory.getLogger(TestPipeline.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + + /** + * Creates and returns a new test pipeline. + * + *

Use {@link DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + */ + public static TestPipeline create() { + if (Boolean.parseBoolean(System.getProperty("runIntegrationTestOnService"))) { + TestDataflowPipelineOptions options = getPipelineOptions(); + LOG.info("Using passed in options: " + options); + return new TestPipeline(createRunner(options), options); + } else { + DirectPipelineRunner directRunner = DirectPipelineRunner.createForTest(); + return new TestPipeline(directRunner, directRunner.getPipelineOptions()); + } + } + + private TestPipeline(PipelineRunner runner, PipelineOptions options) { + super(runner, options); + } + + /** + * Creates and returns a TestDataflowPipelineRunner based on + * configuration via system properties. + */ + private static TestDataflowPipelineRunner createRunner( + TestDataflowPipelineOptions options) { + + DataflowPipelineRunner dataflowRunner = DataflowPipelineRunner + .fromOptions(options); + return new TestDataflowPipelineRunner(dataflowRunner, + new MonitoringUtil.PrintHandler(options.getJobMessageOutput())); + } + + /** + * Creates PipelineOptions for testing with a DataflowPipelineRunner. + */ + static TestDataflowPipelineOptions getPipelineOptions() { + try { + TestDataflowPipelineOptions options = MAPPER.readValue( + System.getProperty(PROPERTY_DATAFLOW_OPTIONS), PipelineOptions.class) + .as(TestDataflowPipelineOptions.class); + options.setAppName(getAppName()); + options.setJobName(getJobName()); + return options; + } catch (IOException e) { + throw new RuntimeException("Unable to instantiate test options from system property " + + PROPERTY_DATAFLOW_OPTIONS + ":" + System.getProperty(PROPERTY_DATAFLOW_OPTIONS), e); + } + } + + /** Returns the class name of the test, or a default name. */ + private static String getAppName() { + Optional stackTraceElement = findCallersStackTrace(); + if (stackTraceElement.isPresent()) { + String className = stackTraceElement.get().getClassName(); + return className.contains(".") + ? className.substring(className.lastIndexOf(".") + 1) + : className; + } + return "UnitTest"; + } + + /** Returns the method name of the test, or a default name. */ + private static String getJobName() { + Optional stackTraceElement = findCallersStackTrace(); + if (stackTraceElement.isPresent()) { + return stackTraceElement.get().getMethodName(); + } + return "unittestjob"; + } + + /** Returns the {@link StackTraceElement} of the calling class. */ + private static Optional findCallersStackTrace() { + Iterator elements = + Iterators.forArray(Thread.currentThread().getStackTrace()); + // First find the TestPipeline class in the stack trace. + while (elements.hasNext()) { + StackTraceElement next = elements.next(); + if (TestPipeline.class.getName().equals(next.getClassName())) { + break; + } + } + // Then find the first instance after which is not the TestPipeline + while (elements.hasNext()) { + StackTraceElement next = elements.next(); + if (!TestPipeline.class.getName().equals(next.getClassName())) { + return Optional.of(next); + } + } + return Optional.absent(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/WindowingFnTestUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/WindowingFnTestUtils.java new file mode 100644 index 000000000000..687cb64530ef --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/WindowingFnTestUtils.java @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A utility class for testing {@link WindowingFn}s. + */ +public class WindowingFnTestUtils { + + /** + * Creates a Set of elements to be used as expected output in + * {@link #runWindowingFn}. + */ + public static Set set(long... timestamps) { + Set result = new HashSet<>(); + for (long timestamp : timestamps) { + result.add(timestampValue(timestamp)); + } + return result; + } + + + /** + * Runs the {@link WindowingFn} over the provided input, returning a map + * of windows to the timestamps in those windows. + */ + public static Map> runWindowingFn( + WindowingFn windowingFn, + List timestamps) throws Exception { + + final TestWindowSet windowSet = new TestWindowSet(); + for (final Long timestamp : timestamps) { + for (W window : windowingFn.assignWindows( + new TestAssignContext(new Instant(timestamp), windowingFn))) { + windowSet.put(window, timestampValue(timestamp)); + } + windowingFn.mergeWindows(new TestMergeContext(windowSet, windowingFn)); + } + Map> actual = new HashMap<>(); + for (W window : windowSet.windows()) { + actual.put(window, windowSet.get(window)); + } + return actual; + } + + private static String timestampValue(long timestamp) { + return "T" + new Instant(timestamp); + } + + /** + * Test implementation of AssignContext. + */ + private static class TestAssignContext + extends WindowingFn.AssignContext { + private Instant timestamp; + + public TestAssignContext(Instant timestamp, WindowingFn windowingFn) { + windowingFn.super(); + this.timestamp = timestamp; + } + + @Override + public T element() { + return null; + } + + @Override + public Instant timestamp() { + return timestamp; + } + + @Override + public Collection windows() { + return null; + } + } + + /** + * Test implementation of MergeContext. + */ + private static class TestMergeContext + extends WindowingFn.MergeContext { + private TestWindowSet windowSet; + + public TestMergeContext( + TestWindowSet windowSet, WindowingFn windowingFn) { + windowingFn.super(); + this.windowSet = windowSet; + } + + @Override + public Collection windows() { + return windowSet.windows(); + } + + @Override + public void merge(Collection toBeMerged, W mergeResult) { + windowSet.merge(toBeMerged, mergeResult); + } + } + + /** + * A WindowSet useful for testing WindowingFns which simply + * collects the placed elements into multisets. + */ + private static class TestWindowSet { + + private Map> elements = new HashMap<>(); + private List> emitted = new ArrayList<>(); + + public void put(W window, V value) { + Set all = elements.get(window); + if (all == null) { + all = new HashSet<>(); + elements.put(window, all); + } + all.add(value); + } + + public void remove(W window) { + elements.remove(window); + } + + public void merge(Collection otherWindows, W window) { + if (otherWindows.isEmpty()) { + return; + } + Set merged = new HashSet<>(); + if (elements.containsKey(window) && !otherWindows.contains(window)) { + merged.addAll(elements.get(window)); + } + for (W w : otherWindows) { + if (!elements.containsKey(w)) { + throw new IllegalArgumentException("Tried to merge a non-existent window:" + w); + } + merged.addAll(elements.get(w)); + elements.remove(w); + } + elements.put(window, merged); + } + + public void markCompleted(W window) {} + + public Collection windows() { + return elements.keySet(); + } + + public boolean contains(W window) { + return elements.containsKey(window); + } + + // For testing. + + public Set get(W window) { + return elements.get(window); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/package-info.java new file mode 100644 index 000000000000..799c1ac98bc8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines utilities for unit testing Dataflow pipelines. The tests for the {@code PTransform}s and + * examples included the Dataflow SDK provide examples of using these utilities. + */ +package com.google.cloud.dataflow.sdk.testing; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Aggregator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Aggregator.java new file mode 100644 index 000000000000..13ad17efa702 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Aggregator.java @@ -0,0 +1,64 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.transforms; + +/** + * An {@code Aggregator} enables arbitrary monitoring in user code. + * + *

Aggregators are created by calling {@link DoFn.Context#createAggregator}, + * typically from {@link DoFn#startBundle}. Elements can be added to the + * {@code Aggregator} by calling {@link Aggregator#addValue}. + * + *

Aggregators are visible in the monitoring UI, when the pipeline is run + * using DataflowPipelineRunner or BlockingDataflowPipelineRunner, along with + * their current value. Aggregators may not become visible until the system + * begins executing the ParDo transform which created them and/or their initial + * value is changed. + * + *

Example: + *

 {@code
+ * class MyDoFn extends DoFn {
+ *   private Aggregator myAggregator;
+ *
+ *   {@literal @}Override
+ *   public void startBundle(Context c) {
+ *     myAggregator = c.createAggregator("myCounter", new Sum.SumIntegerFn());
+ *   }
+ *
+ *   {@literal @}Override
+ *   public void processElement(ProcessContext c) {
+ *     myAggregator.addValue(1);
+ *   }
+ * }
+ * } 
+ * + * @param the type of input values + */ +public interface Aggregator { + + /** + * Adds a new value into the Aggregator. + */ + public void addValue(VI value); + + // TODO: Consider the following additional API conveniences: + // - In addition to createAggregator(), consider adding getAggregator() to + // avoid the need to store the aggregator locally in a DoFn, i.e., create + // if not already present. + // - Add a shortcut for the most common aggregator: + // c.createAggregator("name", new Sum.SumIntegerFn()). +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantiles.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantiles.java new file mode 100644 index 000000000000..ff5687fe30fb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantiles.java @@ -0,0 +1,723 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.UnmodifiableIterator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.PriorityQueue; + +/** + * {@code PTransform}s for getting an idea of a {@code PCollection}'s + * data distribution using approximate {@code N}-tiles, either + * globally or per-key. + */ +public class ApproximateQuantiles { + + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection>} whose sinlge value is a + * {@code List} of the approximate {@code N}-tiles of the elements + * of the input {@code PCollection}. This gives an idea of the + * distribution of the input elements. + * + *

The computed {@code List} is of size {@code numQuantiles}, + * and contains the input elements' minimum value, + * {@code numQuantiles-2} intermediate values, and maximum value, in + * sorted order, using the given {@code Comparator} to order values. + * To compute traditional {@code N}-tiles, one should use + * {@code ApproximateQuantiles.globally(compareFn, N+1)}. + * + *

If there are fewer input elements than {@code numQuantiles}, + * then the result {@code List} will contain all the input elements, + * in sorted order. + * + *

The argument {@code Comparator} must be {@code Serializable}. + * + *

Example of use: + *

 {@code
+   * PCollection pc = ...;
+   * PCollection> quantiles =
+   *     pc.apply(ApproximateQuantiles.globally(stringCompareFn, 11));
+   * } 
+ * + * @param the type of the elements in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + * @param compareFn the function to use to order the elements + */ + public static & Serializable> + PTransform, PCollection>> globally( + int numQuantiles, C compareFn) { + return Combine.globally( + ApproximateQuantilesCombineFn.create(numQuantiles, compareFn)); + } + + /** + * Like {@link #globally(int, Comparator)}, but sorts using the + * elements' natural ordering. + * + * @param the type of the elements in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + */ + public static > + PTransform, PCollection>> globally(int numQuantiles) { + return Combine.globally( + ApproximateQuantilesCombineFn.create(numQuantiles)); + } + + /** + * Returns a {@code PTransform} that takes a + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to a {@code List} of the approximate + * {@code N}-tiles of the values associated with that key in the + * input {@code PCollection}. This gives an idea of the + * distribution of the input values for each key. + * + *

Each of the computed {@code List}s is of size {@code numQuantiles}, + * and contains the input values' minimum value, + * {@code numQuantiles-2} intermediate values, and maximum value, in + * sorted order, using the given {@code Comparator} to order values. + * To compute traditional {@code N}-tiles, one should use + * {@code ApproximateQuantiles.perKey(compareFn, N+1)}. + * + *

If a key has fewer than {@code numQuantiles} values + * associated with it, then that key's output {@code List} will + * contain all the key's input values, in sorted order. + * + *

The argument {@code Comparator} must be {@code Serializable}. + * + *

Example of use: + *

 {@code
+   * PCollection> pc = ...;
+   * PCollection>> quantilesPerKey =
+   *     pc.apply(ApproximateQuantiles.perKey(stringCompareFn, 11));
+   * } 
+ * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + * @param compareFn the function to use to order the elements + */ + public static & Serializable> + PTransform>, PCollection>>> + perKey(int numQuantiles, C compareFn) { + return Combine.perKey( + ApproximateQuantilesCombineFn.create(numQuantiles, compareFn) + .asKeyedFn()); + } + + /** + * Like {@link #perKey(int, Comparator)}, but sorts + * values using the their natural ordering. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + */ + public static > + PTransform>, PCollection>>> + perKey(int numQuantiles) { + return Combine.perKey( + ApproximateQuantilesCombineFn.create(numQuantiles) + .asKeyedFn()); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * The {@code ApproximateQuantilesCombineFn} combiner gives an idea + * of the distribution of a collection of values using approximate + * {@code N}-tiles. The output of this combiner is a {@code List} + * of size {@code numQuantiles}, containing the input values' + * minimum value, {@code numQuantiles-2} intermediate values, and + * maximum value, in sorted order, so for traditional + * {@code N}-tiles, one should use + * {@code ApproximateQuantilesCombineFn#create(N+1)}. + * + *

If there are fewer values to combine than + * {@code numQuantiles}, then the result {@code List} will contain all the + * values being combined, in sorted order. + * + *

Values are ordered using either a specified + * {@code Comparator} or the values' natural ordering. + * + *

To evaluate the quantiles we use the "New Algorithm" described here: + *

+   *   [MRL98] Manku, Rajagopalan & Lindsay, "Approximate Medians and other
+   *   Quantiles in One Pass and with Limited Memory", Proc. 1998 ACM
+   *   SIGMOD, Vol 27, No 2, p 426-435, June 1998.
+   *   http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1&type=pdf
+   * 
+ * + *

The default error bound is {@code 1 / N}, though in practice + * the accuracy tends to be much better.

See + * {@link #create(int, Comparator, long, double)} for + * more information about the meaning of {@code epsilon}, and + * {@link #withEpsilon} for a convenient way to adjust it. + * + * @param the type of the values being combined + */ + public static class ApproximateQuantilesCombineFn + & Serializable> + extends AccumulatingCombineFn + .QuantileState, List> { + + /** + * The cost (in time and space) to compute quantiles to a given + * accuracy is a function of the total number of elements in the + * data set. If an estimate is not known or specified, we use + * this as an upper bound. If this is too low, errors may exceed + * the requested tolerance; if too high, efficiency may be + * non-optimal. The impact is logarithmic with respect to this + * value, so this default should be fine for most uses. + */ + public static final long DEFAULT_MAX_NUM_ELEMENTS = (long) 1e9; + + /** The comparison function to use. */ + private final C compareFn; + + /** + * Number of quantiles to produce. The size of the final output + * list, including the minimum and maximum, is numQuantiles. + */ + private final int numQuantiles; + + /** The size of the buffers, corresponding to k in the referenced paper. */ + private final int bufferSize; + + /** The number of buffers, corresponding to b in the referenced paper. */ + private final int numBuffers; + + private final double epsilon; + private final long maxNumElements; + + /** + * Used to alternate between biasing up and down in the even weight collapse + * operation. + */ + private int offsetJitter = 0; + + /** + * Returns an approximate quantiles combiner with the given + * {@code compareFn} and desired number of quantiles. A total of + * {@code numQuantiles} elements will appear in the output list, + * including the minimum and maximum. + * + *

The {@code Comparator} must be {@code Serializable}. + * + *

The default error bound is {@code 1 / numQuantiles} which + * holds as long as the number of elements is less than + * {@link #DEFAULT_MAX_NUM_ELEMENTS}. + */ + public static & Serializable> + ApproximateQuantilesCombineFn create( + int numQuantiles, C compareFn) { + return create(numQuantiles, compareFn, + DEFAULT_MAX_NUM_ELEMENTS, 1.0 / numQuantiles); + } + + /** + * Like {@link #create(int, Comparator)}, but sorts + * values using their natural ordering. + */ + public static > + ApproximateQuantilesCombineFn> create(int numQuantiles) { + return create(numQuantiles, new Top.Largest()); + } + + /** + * Returns an {@code ApproximateQuantilesCombineFn} that's like + * this one except that it uses the specified {@code epsilon} + * value. Does not modify this combiner. + * + *

See {@link #create(int, Comparator, long, + * double)} for more information about the meaning of + * {@code epsilon}. + */ + public ApproximateQuantilesCombineFn withEpsilon(double epsilon) { + return create(numQuantiles, compareFn, maxNumElements, epsilon); + } + + /** + * Returns an {@code ApproximateQuantilesCombineFn} that's like + * this one except that it uses the specified {@code maxNumElements} + * value. Does not modify this combiner. + * + *

See {@link #create(int, Comparator, long, double)} for more + * information about the meaning of {@code maxNumElements}. + */ + public ApproximateQuantilesCombineFn withMaxInputSize( + long maxNumElements) { + return create(numQuantiles, compareFn, maxNumElements, maxNumElements); + } + + /** + * Creates an approximate quantiles combiner with the given + * {@code compareFn} and desired number of quantiles. A total of + * {@code numQuantiles} elements will appear in the output list, + * including the minimum and maximum. + * + *

The {@code Comparator} must be {@code Serializable}. + * + *

The default error bound is {@code epsilon} which is holds as long + * as the number of elements is less than {@code maxNumElements}. + * Specifically, if one considers the input as a sorted list x_1, ..., x_N, + * then the distance between the each exact quantile x_c and its + * approximation x_c' is bounded by {@code |c - c'| < epsilon * N}. + * Note that these errors are worst-case scenarios; in practice the accuracy + * tends to be much better. + */ + public static & Serializable> + ApproximateQuantilesCombineFn create( + int numQuantiles, + C compareFn, + long maxNumElements, + double epsilon) { + // Compute optimal b and k. + int b = 2; + while ((b - 2) * (1 << (b - 2)) < epsilon * maxNumElements) { + b++; + } + b--; + int k = Math.max(2, (int) Math.ceil(maxNumElements / (1 << (b - 1)))); + return new ApproximateQuantilesCombineFn<>( + numQuantiles, compareFn, k, b, epsilon, maxNumElements); + } + + private ApproximateQuantilesCombineFn(int numQuantiles, + C compareFn, + int bufferSize, + int numBuffers, + double epsilon, + long maxNumElements) { + Preconditions.checkArgument(numQuantiles >= 2); + Preconditions.checkArgument(bufferSize >= 2); + Preconditions.checkArgument(numBuffers >= 2); + Preconditions.checkArgument(compareFn instanceof Serializable); + this.numQuantiles = numQuantiles; + this.compareFn = compareFn; + this.bufferSize = bufferSize; + this.numBuffers = numBuffers; + this.epsilon = epsilon; + this.maxNumElements = maxNumElements; + } + + @Override + public QuantileState createAccumulator() { + return new QuantileState(); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder elementCoder) { + return new QuantileStateCoder(elementCoder); + } + + /** + * Compact summarization of a collection on which quantiles can be + * estimated. + */ + class QuantileState + extends AccumulatingCombineFn + .QuantileState, List> + .Accumulator { + + private T min; + private T max; + + /** + * The set of buffers, ordered by level from smallest to largest. + */ + private PriorityQueue buffers = + new PriorityQueue<>(numBuffers + 1); + + /** + * The algorithm requires that the manipulated buffers always be filled + * to capacity to perform the collapse operation. This operation can + * be extended to buffers of varying sizes by introducing the notion of + * fractional weights, but it's easier to simply combine the remainders + * from all shards into new, full buffers and then take them into account + * when computing the final output. + */ + private List unbufferedElements = Lists.newArrayList(); + + public QuantileState() { } + + public QuantileState(T elem) { + min = elem; + max = elem; + unbufferedElements.add(elem); + } + + public QuantileState(T min, T max, Collection unbufferedElements, + Collection buffers) { + this.min = min; + this.max = max; + this.unbufferedElements.addAll(unbufferedElements); + this.buffers.addAll(buffers); + } + + /** + * Add a new element to the collection being summarized by this state. + */ + @Override + public void addInput(T elem) { + if (isEmpty()) { + min = max = elem; + } else if (compareFn.compare(elem, min) < 0) { + min = elem; + } else if (compareFn.compare(elem, max) > 0) { + max = elem; + } + addUnbuffered(elem); + } + + /** + * Add a new buffer to the unbuffered list, creating a new buffer and + * collapsing if needed. + */ + private void addUnbuffered(T elem) { + unbufferedElements.add(elem); + if (unbufferedElements.size() == bufferSize) { + Collections.sort(unbufferedElements, compareFn); + buffers.add(new QuantileBuffer(unbufferedElements)); + unbufferedElements = Lists.newArrayListWithCapacity(bufferSize); + collapseIfNeeded(); + } + } + + /** + * Updates this as if adding all elements seen by other. + */ + @Override + public void mergeAccumulator(QuantileState other) { + if (other.isEmpty()) { + return; + } + if (min == null || compareFn.compare(other.min, min) < 0) { + min = other.min; + } + if (max == null || compareFn.compare(other.max, max) > 0) { + max = other.max; + } + for (T elem : other.unbufferedElements) { + addUnbuffered(elem); + } + buffers.addAll(other.buffers); + collapseIfNeeded(); + } + + public boolean isEmpty() { + return unbufferedElements.size() == 0 && buffers.size() == 0; + } + + private void collapseIfNeeded() { + while (buffers.size() > numBuffers) { + List toCollapse = Lists.newArrayList(); + toCollapse.add(buffers.poll()); + toCollapse.add(buffers.poll()); + int minLevel = toCollapse.get(1).level; + while (!buffers.isEmpty() && buffers.peek().level == minLevel) { + toCollapse.add(buffers.poll()); + } + buffers.add(collapse(toCollapse)); + } + } + + private QuantileBuffer collapse(Iterable buffers) { + int newLevel = 0; + long newWeight = 0; + for (QuantileBuffer buffer : buffers) { + // As presented in the paper, there should always be at least two + // buffers of the same (minimal) level to collapse, but it is possible + // to violate this condition when combining buffers from independently + // computed shards. If they differ we take the max. + newLevel = Math.max(newLevel, buffer.level + 1); + newWeight += buffer.weight; + } + List newElements = + interpolate(buffers, bufferSize, newWeight, offset(newWeight)); + return new QuantileBuffer(newLevel, newWeight, newElements); + } + + /** + * Outputs numQuantiles elements consisting of the minimum, maximum, and + * numQuantiles - 2 evenly spaced intermediate elements. + * + * Returns the empty list if no elements have been added. + */ + @Override + public List extractOutput() { + if (isEmpty()) { + return Lists.newArrayList(); + } + long totalCount = unbufferedElements.size(); + for (QuantileBuffer buffer : buffers) { + totalCount += bufferSize * buffer.weight; + } + List all = Lists.newArrayList(buffers); + if (!unbufferedElements.isEmpty()) { + Collections.sort(unbufferedElements, compareFn); + all.add(new QuantileBuffer(unbufferedElements)); + } + double step = 1.0 * totalCount / (numQuantiles - 1); + double offset = (1.0 * totalCount - 1) / (numQuantiles - 1); + List quantiles = interpolate(all, numQuantiles - 2, step, offset); + quantiles.add(0, min); + quantiles.add(max); + return quantiles; + } + } + + /** + * A single buffer in the sense of the referenced algorithm. + */ + private class QuantileBuffer implements Comparable { + private int level; + private long weight; + private List elements; + + public QuantileBuffer(List elements) { + this(0, 1, elements); + } + + public QuantileBuffer(int level, long weight, List elements) { + this.level = level; + this.weight = weight; + this.elements = elements; + } + + @Override + public int compareTo(QuantileBuffer other) { + return this.level - other.level; + } + + @Override + public String toString() { + return "QuantileBuffer[" + + "level=" + level + + ", weight=" + + weight + ", elements=" + elements + "]"; + } + + public Iterator> weightedIterator() { + return new UnmodifiableIterator>() { + Iterator iter = elements.iterator(); + @Override public boolean hasNext() { return iter.hasNext(); } + @Override public WeightedElement next() { + return WeightedElement.of(weight, iter.next()); + } + }; + } + } + + /** + * Coder for QuantileState. + */ + private class QuantileStateCoder extends CustomCoder { + + private final Coder elementCoder; + private final Coder> elementListCoder; + + public QuantileStateCoder(Coder elementCoder) { + this.elementCoder = elementCoder; + this.elementListCoder = ListCoder.of(elementCoder); + } + + @Override + public void encode( + QuantileState state, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + elementCoder.encode(state.min, outStream, nestedContext); + elementCoder.encode(state.max, outStream, nestedContext); + elementListCoder.encode( + state.unbufferedElements, outStream, nestedContext); + BigEndianIntegerCoder.of().encode( + state.buffers.size(), outStream, nestedContext); + for (QuantileBuffer buffer : state.buffers) { + encodeBuffer(buffer, outStream, nestedContext); + } + } + + @Override + public QuantileState decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + T min = elementCoder.decode(inStream, nestedContext); + T max = elementCoder.decode(inStream, nestedContext); + List unbufferedElements = + elementListCoder.decode(inStream, nestedContext); + int numBuffers = + BigEndianIntegerCoder.of().decode(inStream, nestedContext); + List buffers = new ArrayList<>(numBuffers); + for (int i = 0; i < numBuffers; i++) { + buffers.add(decodeBuffer(inStream, nestedContext)); + } + return new QuantileState(min, max, unbufferedElements, buffers); + } + + private void encodeBuffer( + QuantileBuffer buffer, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + DataOutputStream outData = new DataOutputStream(outStream); + outData.writeInt(buffer.level); + outData.writeLong(buffer.weight); + elementListCoder.encode(buffer.elements, outStream, context); + } + + private QuantileBuffer decodeBuffer( + InputStream inStream, Coder.Context context) + throws IOException, CoderException { + DataInputStream inData = new DataInputStream(inStream); + return new QuantileBuffer( + inData.readInt(), + inData.readLong(), + elementListCoder.decode(inStream, context)); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + QuantileState state, + ElementByteSizeObserver observer, + Coder.Context context) + throws Exception { + Coder.Context nestedContext = context.nested(); + elementCoder.registerByteSizeObserver( + state.min, observer, nestedContext); + elementCoder.registerByteSizeObserver( + state.max, observer, nestedContext); + elementListCoder.registerByteSizeObserver( + state.unbufferedElements, observer, nestedContext); + + BigEndianIntegerCoder.of().registerByteSizeObserver( + state.buffers.size(), observer, nestedContext); + for (QuantileBuffer buffer : state.buffers) { + observer.update(4L + 8); + + elementListCoder.registerByteSizeObserver( + buffer.elements, observer, nestedContext); + } + } + + @Override + public boolean isDeterministic() { + return elementListCoder.isDeterministic(); + } + } + + /** + * If the weight is even, we must round up our down. Alternate between + * these two options to avoid a bias. + */ + private long offset(long newWeight) { + if (newWeight % 2 == 1) { + return (newWeight + 1) / 2; + } else { + offsetJitter = 2 - offsetJitter; + return (newWeight + offsetJitter) / 2; + } + } + + /** + * Emulates taking the ordered union of all elements in buffers, repeated + * according to their weight, and picking out the (k * step + offset)-th + * elements of this list for {@code 0 <= k < count}. + */ + private List interpolate(Iterable buffers, + int count, double step, double offset) { + List>> iterators = Lists.newArrayList(); + for (QuantileBuffer buffer : buffers) { + iterators.add(buffer.weightedIterator()); + } + // Each of the buffers is already sorted by element. + Iterator> sorted = Iterators.mergeSorted( + iterators, + new Comparator>() { + @Override + public int compare(WeightedElement a, WeightedElement b) { + return compareFn.compare(a.value, b.value); + } + }); + + List newElements = Lists.newArrayListWithCapacity(count); + WeightedElement weightedElement = sorted.next(); + double current = weightedElement.weight; + for (int j = 0; j < count; j++) { + double target = j * step + offset; + while (current <= target && sorted.hasNext()) { + weightedElement = sorted.next(); + current += weightedElement.weight; + } + newElements.add(weightedElement.value); + } + return newElements; + } + + /** An element and its weight. */ + private static class WeightedElement { + public long weight; + public T value; + private WeightedElement(long weight, T value) { + this.weight = weight; + this.value = value; + } + public static WeightedElement of(long weight, T value) { + return new WeightedElement<>(weight, value); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUnique.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUnique.java new file mode 100644 index 000000000000..9308a010a2a9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUnique.java @@ -0,0 +1,426 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.hash.Hashing; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.PriorityQueue; + +/** + * {@code PTransform}s for estimating the number of distinct elements + * in a {@code PCollection}, or the number of distinct values + * associated with each key in a {@code PCollection} of {@code KV}s. + */ +public class ApproximateUnique { + + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection} containing a single value + * that is an estimate of the number of distinct elements in the + * input {@code PCollection}. + * + *

The {@code sampleSize} parameter controls the estimation + * error. The error is about {@code 2 / sqrt(sampleSize)}, so for + * {@code ApproximateUnique.globally(10000)} the estimation error is + * about 2%. Similarly, for {@code ApproximateUnique.of(16)} the + * estimation error is about 50%. If there are fewer than + * {@code sampleSize} distinct elements then the returned result + * will be exact with extremely high probability (the chance of a + * hash collision is about {@code sampleSize^2 / 2^65}). + * + *

This transform approximates the number of elements in a set + * by computing the top {@code sampleSize} hash values, and using + * that to extrapolate the size of the entire set of hash values by + * assuming the rest of the hash values are as densely distributed + * as the top {@code sampleSize}. + * + *

See also {@link #globally(double)}. + * + *

Example of use: + *

 {@code
+   * PCollection pc = ...;
+   * PCollection approxNumDistinct =
+   *     pc.apply(ApproximateUnique.globally(1000));
+   * } 
+ * + * @param the type of the elements in the input {@code PCollection} + * @param sampleSize the number of entries in the statistical + * sample; the higher this number, the more accurate the + * estimate will be; should be {@code >= 16} + * @throws IllegalArgumentException if the {@code sampleSize} + * argument is too small + */ + public static Globally globally(int sampleSize) { + return new Globally<>(sampleSize); + } + + /** + * Like {@link #globally(int)}, but specifies the desired maximum + * estimation error instead of the sample size. + * + * @param the type of the elements in the input {@code PCollection} + * @param maximumEstimationError the maximum estimation error, which + * should be in the range {@code [0.01, 0.5]} + * @throws IllegalArgumentException if the + * {@code maximumEstimationError} argument is out of range + */ + public static Globally globally(double maximumEstimationError) { + return new Globally<>(maximumEstimationError); + } + + /** + * Returns a {@code PTransform} that takes a + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output element + * mapping each distinct key in the input {@code PCollection} to an + * estimate of the number of distinct values associated with that + * key in the input {@code PCollection}. + * + *

See {@link #globally(int)} for an explanation of the + * {@code sampleSize} parameter. A separate sampling is computed + * for each distinct key of the input. + * + *

See also {@link #perKey(double)}. + * + *

Example of use: + *

 {@code
+   * PCollection> pc = ...;
+   * PCollection> approxNumDistinctPerKey =
+   *     pc.apply(ApproximateUnique.perKey(1000));
+   * } 
+ * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param sampleSize the number of entries in the statistical + * sample; the higher this number, the more accurate the + * estimate will be; should be {@code >= 16} + * @throws IllegalArgumentException if the {@code sampleSize} + * argument is too small + */ + public static PerKey perKey(int sampleSize) { + return new PerKey<>(sampleSize); + } + + /** + * Like {@link #perKey(int)}, but specifies the desired maximum + * estimation error instead of the sample size. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param maximumEstimationError the maximum estimation error, which + * should be in the range {@code [0.01, 0.5]} + * @throws IllegalArgumentException if the + * {@code maximumEstimationError} argument is out of range + */ + public static PerKey perKey(double maximumEstimationError) { + return new PerKey<>(maximumEstimationError); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code PTransform} for estimating the number of distinct elements + * in a {@code PCollection}. + * + * @param the type of the elements in the input {@code PCollection} + */ + static class Globally extends PTransform, PCollection> { + + /** + * The number of entries in the statistical sample; the higher this number, + * the more accurate the estimate will be. + */ + private final long sampleSize; + + /** + * @see ApproximateUnique#globally(int) + */ + public Globally(int sampleSize) { + if (sampleSize < 16) { + throw new IllegalArgumentException( + "ApproximateUnique needs a sampleSize " + + ">= 16 for an estimation error <= 50%. " + + "In general, the estimation " + + "error is about 2 / sqrt(sampleSize)."); + } + this.sampleSize = sampleSize; + } + + /** + * @see ApproximateUnique#globally(double) + */ + public Globally(double maximumEstimationError) { + if (maximumEstimationError < 0.01 || maximumEstimationError > 0.5) { + throw new IllegalArgumentException( + "ApproximateUnique needs an " + + "estimation error between 1% (0.01) and 50% (0.5)."); + } + this.sampleSize = sampleSizeFromEstimationError(maximumEstimationError); + } + + @Override + public PCollection apply(PCollection input) { + Coder coder = input.getCoder(); + return input.apply( + Combine.globally( + new ApproximateUniqueCombineFn<>(sampleSize, coder))); + } + + @Override + protected String getKindString() { + return "ApproximateUnique.Globally"; + } + } + + /** + * {@code PTransform} for estimating the number of distinct values + * associated with each key in a {@code PCollection} of {@code KV}s. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + */ + static class PerKey + extends PTransform>, PCollection>> { + + private final long sampleSize; + + /** + * @see ApproximateUnique#perKey(int) + */ + public PerKey(int sampleSize) { + if (sampleSize < 16) { + throw new IllegalArgumentException( + "ApproximateUnique needs a " + + "sampleSize >= 16 for an estimation error <= 50%. In general, " + + "the estimation error is about 2 / sqrt(sampleSize)."); + } + this.sampleSize = sampleSize; + } + + /** + * @see ApproximateUnique#perKey(double) + */ + public PerKey(double estimationError) { + if (estimationError < 0.01 || estimationError > 0.5) { + throw new IllegalArgumentException( + "ApproximateUnique.PerKey needs an " + + "estimation error between 1% (0.01) and 50% (0.5)."); + } + this.sampleSize = sampleSizeFromEstimationError(estimationError); + } + + @Override + public PCollection> apply(PCollection> input) { + Coder> inputCoder = input.getCoder(); + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "ApproximateUnique.PerKey requires its input to use KvCoder"); + } + @SuppressWarnings("unchecked") + final Coder coder = ((KvCoder) inputCoder).getValueCoder(); + + return input.apply( + Combine.perKey(new ApproximateUniqueCombineFn<>( + sampleSize, coder).asKeyedFn())); + } + + @Override + protected String getKindString() { + return "ApproximateUnique.PerKey"; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code CombineFn} that computes an estimate of the number of + * distinct values that were combined. + * + *

Hashes input elements, computes the top {@code sampleSize} + * hash values, and uses those to extrapolate the size of the entire + * set of hash values by assuming the rest of the hash values are as + * densely distributed as the top {@code sampleSize}. + * + *

Used to implement + * {@link #globally(int) ApproximatUnique.globally(...)} and + * {@link #perKey(int) ApproximatUnique.perKey(...)}. + * + * @param the type of the values being combined + */ + public static class ApproximateUniqueCombineFn extends + CombineFn { + + /** + * The size of the space of hashes returned by the hash function. + */ + static final double HASH_SPACE_SIZE = + Long.MAX_VALUE - (double) Long.MIN_VALUE; + + /** + * A heap utility class to efficiently track the largest added elements. + */ + public static class LargestUnique implements Serializable { + private PriorityQueue heap = new PriorityQueue<>(); + private final long sampleSize; + + /** + * Creates a heap to track the largest {@code sampleSize} elements. + * + * @param sampleSize the size of the heap + */ + public LargestUnique(long sampleSize) { + this.sampleSize = sampleSize; + } + + /** + * Adds a value to the heap, returning whether the value is (large enough + * to be) in the heap. + */ + public boolean add(Long value) { + if (heap.contains(value)) { + return true; + } else if (heap.size() < sampleSize) { + heap.add(value); + return true; + } else if (value > heap.element()) { + heap.remove(); + heap.add(value); + return true; + } else { + return false; + } + } + + /** + * Returns the values in the heap, ordered largest to smallest. + */ + public List extractOrderedList() { + // The only way to extract the order from the heap is element-by-element + // from smallest to largest. + Long[] array = new Long[heap.size()]; + for (int i = heap.size() - 1; i >= 0; i--) { + array[i] = heap.remove(); + } + return Arrays.asList(array); + } + } + + private final long sampleSize; + private final Coder coder; + + public ApproximateUniqueCombineFn(long sampleSize, Coder coder) { + this.sampleSize = sampleSize; + this.coder = coder; + } + + @Override + public LargestUnique createAccumulator() { + return new LargestUnique(sampleSize); + } + + @Override + public void addInput(LargestUnique heap, T input) { + try { + heap.add(hash(input, coder)); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @Override + public LargestUnique mergeAccumulators(Iterable heaps) { + Iterator iterator = heaps.iterator(); + LargestUnique heap = iterator.next(); + while (iterator.hasNext()) { + List largestHashes = iterator.next().extractOrderedList(); + for (long hash : largestHashes) { + if (!heap.add(hash)) { + break; // The remainder of this list is all smaller. + } + } + } + return heap; + } + + @Override + public Long extractOutput(LargestUnique heap) { + List largestHashes = heap.extractOrderedList(); + if (largestHashes.size() < sampleSize) { + return (long) largestHashes.size(); + } else { + long smallestSampleHash = largestHashes.get(largestHashes.size() - 1); + double sampleSpaceSize = Long.MAX_VALUE - (double) smallestSampleHash; + // This formula takes into account the possibility of hash collisions, + // which become more likely than not for 2^32 distinct elements. + // Note that log(1+x) ~ x for small x, so for sampleSize << maxHash + // log(1 - sampleSize/sampleSpace) / log(1 - 1/sampleSpace) ~ sampleSize + // and hence estimate ~ sampleSize * HASH_SPACE_SIZE / sampleSpace + // as one would expect. + double estimate = Math.log1p(-sampleSize / sampleSpaceSize) + / Math.log1p(-1 / sampleSpaceSize) + * HASH_SPACE_SIZE / sampleSpaceSize; + return Math.round(estimate); + } + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, + Coder inputCoder) { + return SerializableCoder.of(LargestUnique.class); + } + + /** + * Encodes the given element using the given coder and hashes the encoding. + */ + static long hash(T element, Coder coder) + throws CoderException, IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + coder.encode(element, baos, Context.OUTER); + return Hashing.murmur3_128().hashBytes(baos.toByteArray()).asLong(); + } + } + + /** + * Computes the sampleSize based on the desired estimation error. + * + * @param estimationError should be bounded by [0.01, 0.5] + * @return the sample size needed for the desired estimation error + */ + static long sampleSizeFromEstimationError(double estimationError) { + return Math.round(Math.ceil(4.0 / Math.pow(estimationError, 2.0))); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java new file mode 100644 index 000000000000..9b374665451e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java @@ -0,0 +1,1045 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +/** + * {@code PTransform}s for combining {@code PCollection} elements + * globally and per-key. + */ +public class Combine { + + /** + * Returns a {@link Globally Combine.Globally} {@code PTransform} + * that uses the given {@code SerializableFunction} to combine all + * the elements of the input {@code PCollection} into a singleton + * {@code PCollection} value. The types of the input elements and the + * output value must be the same. + * + *

If the input {@code PCollection} is empty, the ouput will contain a the + * default value of the combining function if the input is windowed into + * the {@link GlobalWindow}; otherwise, the output will be empty. Note: this + * behavior is subject to change. + * + *

See {@link Globally Combine.Globally} for more information. + */ + public static Globally globally( + SerializableFunction, V> combiner) { + return globally(SimpleCombineFn.of(combiner)); + } + + /** + * Returns a {@link Globally Combine.Globally} {@code PTransform} + * that uses the given {@code CombineFn} to combine all the elements + * of the input {@code PCollection} into a singleton {@code PCollection} + * value. The types of the input elements and the output value can + * differ. + * + * If the input {@code PCollection} is empty, the ouput will contain a the + * default value of the combining function if the input is windowed into + * the {@link GlobalWindow}; otherwise, the output will be empty. Note: this + * behavior is subject to change. + * + *

See {@link Globally Combine.Globally} for more information. + */ + public static Globally globally( + CombineFn fn) { + return new Globally<>(fn); + } + + /** + * Returns a {@link PerKey Combine.PerKey} {@code PTransform} that + * first groups its input {@code PCollection} of {@code KV}s by keys and + * windows, then invokes the given function on each of the values lists to + * produce a combined value, and then returns a {@code PCollection} + * of {@code KV}s mapping each distinct key to its combined value for each + * window. + * + *

Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * as the input. + * + *

See {@link PerKey Combine.PerKey} for more information. + */ + public static PerKey perKey( + SerializableFunction, V> fn) { + return perKey(Combine.SimpleCombineFn.of(fn)); + } + + /** + * Returns a {@link PerKey Combine.PerKey} {@code PTransform} that + * first groups its input {@code PCollection} of {@code KV}s by keys and + * windows, then invokes the given function on each of the values lists to + * produce a combined value, and then returns a {@code PCollection} + * of {@code KV}s mapping each distinct key to its combined value for each + * window. + * + *

Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * as the input. + * + *

See {@link PerKey Combine.PerKey} for more information. + */ + public static PerKey perKey( + CombineFn fn) { + return perKey(fn.asKeyedFn()); + } + + /** + * Returns a {@link PerKey Combine.PerKey} {@code PTransform} that + * first groups its input {@code PCollection} of {@code KV}s by keys and + * windows, then invokes the given function on each of the key/values-lists + * pairs to produce a combined value, and then returns a + * {@code PCollection} of {@code KV}s mapping each distinct key to + * its combined value for each window. + * + *

Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * as the input. + * + *

See {@link PerKey Combine.PerKey} for more information. + */ + public static PerKey perKey( + KeyedCombineFn fn) { + return new PerKey<>(fn); + } + + /** + * Returns a {@link GroupedValues Combine.GroupedValues} + * {@code PTransform} that takes a {@code PCollection} of + * {@code KV}s where a key maps to an {@code Iterable} of values, e.g., + * the result of a {@code GroupByKey}, then uses the given + * {@code SerializableFunction} to combine all the values associated + * with a key, ignoring the key. The type of the input and + * output values must be the same. + * + *

Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

See {@link GroupedValues Combine.GroupedValues} for more information. + * + *

Note that {@link #perKey(SerializableFunction)} is typically + * more convenient to use than {@link GroupByKey} followed by + * {@code groupedValues(...)}. + */ + public static GroupedValues groupedValues( + SerializableFunction, V> fn) { + return groupedValues(SimpleCombineFn.of(fn)); + } + + /** + * Returns a {@link GroupedValues Combine.GroupedValues} + * {@code PTransform} that takes a {@code PCollection} of + * {@code KV}s where a key maps to an {@code Iterable} of values, e.g., + * the result of a {@code GroupByKey}, then uses the given + * {@code CombineFn} to combine all the values associated with a + * key, ignoring the key. The types of the input and output values + * can differ. + * + *

Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

See {@link GroupedValues Combine.GroupedValues} for more information. + * + *

Note that {@link #perKey(CombineFn)} is typically + * more convenient to use than {@link GroupByKey} followed by + * {@code groupedValues(...)}. + */ + public static GroupedValues groupedValues( + CombineFn fn) { + return groupedValues(fn.asKeyedFn()); + } + + /** + * Returns a {@link GroupedValues Combine.GroupedValues} + * {@code PTransform} that takes a {@code PCollection} of + * {@code KV}s where a key maps to an {@code Iterable} of values, e.g., + * the result of a {@code GroupByKey}, then uses the given + * {@code KeyedCombineFn} to combine all the values associated with + * each key. The combining function is provided the key. The types + * of the input and output values can differ. + * + *

Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

See {@link GroupedValues Combine.GroupedValues} for more information. + * + *

Note that {@link #perKey(KeyedCombineFn)} is typically + * more convenient to use than {@link GroupByKey} followed by + * {@code groupedValues(...)}. + */ + public static GroupedValues groupedValues( + KeyedCombineFn fn) { + return new GroupedValues<>(fn); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code CombineFn} specifies how to combine a + * collection of input values of type {@code VI} into a single + * output value of type {@code VO}. It does this via one or more + * intermediate mutable accumulator values of type {@code VA}. + * + *

The overall process to combine a collection of input + * {@code VI} values into a single output {@code VO} value is as + * follows: + * + *

    + * + *
  1. The input {@code VI} values are partitioned into one or more + * batches. + * + *
  2. For each batch, the {@link #createAccumulator} operation is + * invoked to create a fresh mutable accumulator value of type + * {@code VA}, initialized to represent the combination of zero + * values. + * + *
  3. For each input {@code VI} value in a batch, the + * {@link #addInput} operation is invoked to add the value to that + * batch's accumulator {@code VA} value. The accumulator may just + * record the new value (e.g., if {@code VA == List}, or may do + * work to represent the combination more compactly. + * + *
  4. The {@link #mergeAccumulators} operation is invoked to + * combine a collection of accumulator {@code VA} values into a + * single combined output accumulator {@code VA} value, once the + * merging accumulators have had all all the input values in their + * batches added to them. This operation is invoked repeatedly, + * until there is only one accumulator value left. + * + *
  5. The {@link #extractOutput} operation is invoked on the final + * accumulator {@code VA} value to get the output {@code VO} value. + * + *
+ * + *

For example: + *

 {@code
+   * public class AverageFn extends CombineFn {
+   *   public static class Accum {
+   *     int sum = 0;
+   *     int count = 0;
+   *   }
+   *   public Accum createAccumulator() { return new Accum(); }
+   *   public void addInput(Accum accum, Integer input) {
+   *       accum.sum += input;
+   *       accum.count++;
+   *   }
+   *   public Accum mergeAccumulators(Iterable accums) {
+   *     Accum merged = createAccumulator();
+   *     for (Accum accum : accums) {
+   *       merged.sum += accum.sum;
+   *       merged.count += accum.count;
+   *     }
+   *     return merged;
+   *   }
+   *   public Double extractOutput(Accum accum) {
+   *     return ((double) accum.sum) / accum.count;
+   *   }
+   * }
+   * PCollection pc = ...;
+   * PCollection average = pc.apply(Combine.globally(new AverageFn()));
+   * } 
+ * + *

Combining functions used by {@link Combine.Globally}, + * {@link Combine.PerKey}, {@link Combine.GroupedValues}, and + * {@code PTransforms} derived from them should be + * associative and commutative. Associativity is + * required because input values are first broken up into subgroups + * before being combined, and their intermediate results further + * combined, in an arbitrary tree structure. Commutativity is + * required because any order of the input values is ignored when + * breaking up input values into groups. + * + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + public abstract static class CombineFn implements Serializable { + /** + * Returns a new, mutable accumulator value, representing the + * accumulation of zero input values. + */ + public abstract VA createAccumulator(); + + /** + * Adds the given input value to the given accumulator, + * modifying the accumulator. + */ + public abstract void addInput(VA accumulator, VI input); + + /** + * Returns an accumulator representing the accumulation of all the + * input values accumulated in the merging accumulators. + * + *

May modify any of the argument accumulators. May return a + * fresh accumulator, or may return one of the (modified) argument + * accumulators. + */ + public abstract VA mergeAccumulators(Iterable accumulators); + + /** + * Returns the output value that is the result of combining all + * the input values represented by the given accumulator. + */ + public abstract VO extractOutput(VA accumulator); + + /** + * Applies this {@code CombineFn} to a collection of input values + * to produce a combined output value. + * + *

Useful when testing the behavior of a {@code CombineFn} + * separately from a {@code Combine} transform. + */ + public VO apply(Iterable inputs) { + VA accum = createAccumulator(); + for (VI input : inputs) { + addInput(accum, input); + } + return extractOutput(accum); + } + + /** + * Returns the {@code Coder} to use for accumulator {@code VA} + * values, or null if it is not able to be inferred. + * + *

By default, uses the knowledge of the {@code Coder} being used + * for {@code VI} values and the enclosing {@code Pipeline}'s + * {@code CoderRegistry} to try to infer the Coder for {@code VA} + * values. + */ + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return registry.getDefaultCoder( + getClass(), + CombineFn.class, + ImmutableMap.of("VI", inputCoder), + "VA"); + } + + /** + * Returns the {@code Coder} to use by default for output + * {@code VO} values, or null if it is not able to be inferred. + * + *

By default, uses the knowledge of the {@code Coder} being + * used for input {@code VI} values and the enclosing + * {@code Pipeline}'s {@code CoderRegistry} to try to infer the + * Coder for {@code VO} values. + */ + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) { + return registry.getDefaultCoder( + getClass(), + CombineFn.class, + ImmutableMap.of("VI", inputCoder, + "VA", getAccumulatorCoder(registry, inputCoder)), + "VO"); + } + + /** + * Converts this {@code CombineFn} into an equivalent + * {@link KeyedCombineFn}, which ignores the keys passed to it and + * combines the values according to this {@code CombineFn}. + * + * @param the type of the (ignored) keys + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public KeyedCombineFn asKeyedFn() { + // The key, an object, is never even looked at. + return new KeyedCombineFn() { + @Override + public VA createAccumulator(K key) { + return CombineFn.this.createAccumulator(); + } + + @Override + public void addInput(K key, VA accumulator, VI input) { + CombineFn.this.addInput(accumulator, input); + } + + @Override + public VA mergeAccumulators(K key, Iterable accumulators) { + return CombineFn.this.mergeAccumulators(accumulators); + } + + @Override + public VO extractOutput(K key, VA accumulator) { + return CombineFn.this.extractOutput(accumulator); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) { + return CombineFn.this.getAccumulatorCoder(registry, inputCoder); + } + + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) { + return CombineFn.this.getDefaultOutputCoder(registry, inputCoder); + } + }; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code CombineFn} that uses a subclass of + * {@link AccumulatingCombineFn.Accumulator} as its accumulator + * type. By defining the operations of the {@code Accumulator} + * helper class, the operations of the enclosing {@code CombineFn} + * are automatically provided. This can reduce the code required to + * implement a {@code CombineFn}. + * + *

For example, the example from {@link CombineFn} above can be + * expressed using {@code AccumulatingCombineFn} more concisely as + * follows: + * + *

 {@code
+   * public class AverageFn
+   *     extends AccumulatingCombineFn {
+   *   public Accum createAccumulator() { return new Accum(); }
+   *   public class Accum
+   *       extends AccumulatingCombineFn
+   *               .Accumulator {
+   *     private int sum = 0;
+   *     private int count = 0;
+   *     public void addInput(Integer input) {
+   *       sum += input;
+   *       count++;
+   *     }
+   *     public void mergeAccumulator(Accum other) {
+   *       sum += other.sum;
+   *       count += other.count;
+   *     }
+   *     public Double extractOutput() {
+   *       return ((double) sum) / count;
+   *     }
+   *   }
+   * }
+   * PCollection pc = ...;
+   * PCollection average = pc.apply(Combine.globally(new AverageFn()));
+   * } 
+ * + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + public abstract static class AccumulatingCombineFn + .Accumulator, VO> + extends CombineFn { + + /** + * The type of mutable accumulator values used by this + * {@code AccumulatingCombineFn}. + */ + public abstract class Accumulator implements Serializable { + /** + * Adds the given input value to this accumulator, modifying + * this accumulator. + */ + public abstract void addInput(VI input); + + /** + * Adds the input values represented by the given accumulator + * into this accumulator. + */ + public abstract void mergeAccumulator(VA other); + + /** + * Returns the output value that is the result of combining all + * the input values represented by this accumulator. + */ + public abstract VO extractOutput(); + } + + @Override + public final void addInput(VA accumulator, VI input) { + accumulator.addInput(input); + } + + @Override + public final VA mergeAccumulators(Iterable accumulators) { + VA accumulator = createAccumulator(); + for (VA partial : accumulators) { + accumulator.mergeAccumulator(partial); + } + return accumulator; + } + + @Override + public final VO extractOutput(VA accumulator) { + return accumulator.extractOutput(); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + + /** + * A {@code KeyedCombineFn} specifies how to combine + * a collection of input values of type {@code VI}, associated with + * a key of type {@code K}, into a single output value of type + * {@code VO}. It does this via one or more intermediate mutable + * accumulator values of type {@code VA}. + * + *

The overall process to combine a collection of input + * {@code VI} values associated with an input {@code K} key into a + * single output {@code VO} value is as follows: + * + *

    + * + *
  1. The input {@code VI} values are partitioned into one or more + * batches. + * + *
  2. For each batch, the {@link #createAccumulator} operation is + * invoked to create a fresh mutable accumulator value of type + * {@code VA}, initialized to represent the combination of zero + * values. + * + *
  3. For each input {@code VI} value in a batch, the + * {@link #addInput} operation is invoked to add the value to that + * batch's accumulator {@code VA} value. The accumulator may just + * record the new value (e.g., if {@code VA == List}, or may do + * work to represent the combination more compactly. + * + *
  4. The {@link #mergeAccumulators} operation is invoked to + * combine a collection of accumulator {@code VA} values into a + * single combined output accumulator {@code VA} value, once the + * merging accumulators have had all all the input values in their + * batches added to them. This operation is invoked repeatedly, + * until there is only one accumulator value left. + * + *
  5. The {@link #extractOutput} operation is invoked on the final + * accumulator {@code VA} value to get the output {@code VO} value. + * + *
+ * + * All of these operations are passed the {@code K} key that the + * values being combined are associated with. + * + *

For example: + *

 {@code
+   * public class ConcatFn
+   *     extends KeyedCombineFn {
+   *   public static class Accum {
+   *     String s = "";
+   *   }
+   *   public Accum createAccumulator(String key) { return new Accum(); }
+   *   public void addInput(String key, Accum accum, Integer input) {
+   *       accum.s += "+" + input;
+   *   }
+   *   public Accum mergeAccumulators(String key, Iterable accums) {
+   *     Accum merged = new Accum();
+   *     for (Accum accum : accums) {
+   *       merged.s += accum.s;
+   *     }
+   *     return merged;
+   *   }
+   *   public String extractOutput(String key, Accum accum) {
+   *     return key + accum.s;
+   *   }
+   * }
+   * PCollection> pc = ...;
+   * PCollection> pc2 = pc.apply(
+   *     Combine.perKey(new ConcatFn()));
+   * } 
+ * + *

Keyed combining functions used by {@link Combine.PerKey}, + * {@link Combine.GroupedValues}, and {@code PTransforms} derived + * from them should be associative and commutative. + * Associativity is required because input values are first broken + * up into subgroups before being combined, and their intermediate + * results further combined, in an arbitrary tree structure. + * Commutativity is required because any order of the input values + * is ignored when breaking up input values into groups. + * + * @param type of keys + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + public abstract static class KeyedCombineFn + implements Serializable { + /** + * Returns a new, mutable accumulator value representing the + * accumulation of zero input values. + * + * @param key the key that all the accumulated values using the + * accumulator are associated with + */ + public abstract VA createAccumulator(K key); + + /** + * Adds the given input value to the given accumulator, + * modifying the accumulator. + * + * @param key the key that all the accumulated values using the + * accumulator are associated with + */ + public abstract void addInput(K key, VA accumulator, VI value); + + /** + * Returns an accumulator representing the accumulation of all the + * input values accumulated in the merging accumulators. + * + *

May modify any of the argument accumulators. May return a + * fresh accumulator, or may return one of the (modified) argument + * accumulators. + * + * @param key the key that all the accumulators are associated + * with + */ + public abstract VA mergeAccumulators(K key, Iterable accumulators); + + /** + * Returns the output value that is the result of combining all + * the input values represented by the given accumulator. + * + * @param key the key that all the accumulated values using the + * accumulator are associated with + */ + public abstract VO extractOutput(K key, VA accumulator); + + /** + * Applies this {@code KeyedCombineFn} to a key and a collection + * of input values to produce a combined output value. + * + *

Useful when testing the behavior of a {@code KeyedCombineFn} + * separately from a {@code Combine} transform. + */ + public VO apply(K key, Iterable inputs) { + VA accum = createAccumulator(key); + for (VI input : inputs) { + addInput(key, accum, input); + } + return extractOutput(key, accum); + } + + /** + * Returns the {@code Coder} to use for accumulator {@code VA} + * values, or null if it is not able to be inferred. + * + *

By default, uses the knowledge of the {@code Coder} being + * used for {@code K} keys and input {@code VI} values and the + * enclosing {@code Pipeline}'s {@code CoderRegistry} to try to + * infer the Coder for {@code VA} values. + */ + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) { + return registry.getDefaultCoder( + getClass(), + KeyedCombineFn.class, + ImmutableMap.of("K", keyCoder, "VI", inputCoder), + "VA"); + } + + /** + * Returns the {@code Coder} to use by default for output + * {@code VO} values, or null if it is not able to be inferred. + * + *

By default, uses the knowledge of the {@code Coder} being + * used for {@code K} keys and input {@code VI} values and the + * enclosing {@code Pipeline}'s {@code CoderRegistry} to try to + * infer the Coder for {@code VO} values. + */ + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) { + return registry.getDefaultCoder( + getClass(), + KeyedCombineFn.class, + ImmutableMap.of( + "K", keyCoder, + "VI", inputCoder, + "VA", getAccumulatorCoder(registry, keyCoder, inputCoder)), + "VO"); + } + } + + + //////////////////////////////////////////////////////////////////////////// + + /** + * {@code Combine.Globally} takes a {@code PCollection} + * and returns a {@code PCollection} whose single element is the result of + * combining all the elements of the input {@code PCollection}, + * using a specified + * {@link CombineFn CombineFn}. It is common + * for {@code VI == VO}, but not required. Common combining + * functions include sums, mins, maxes, and averages of numbers, + * conjunctions and disjunctions of booleans, statistical + * aggregations, etc. + * + *

Example of use: + *

 {@code
+   * PCollection pc = ...;
+   * PCollection sum = pc.apply(
+   *     Combine.globally(new Sum.SumIntegerFn()));
+   * } 
+ * + *

Combining can happen in parallel, with different subsets of the + * input {@code PCollection} being combined separately, and their + * intermediate results combined further, in an arbitrary tree + * reduction pattern, until a single result value is produced. + * + *

By default, the {@code Coder} of the output {@code PValue} + * is inferred from the concrete type of the + * {@code CombineFn}'s output type {@code VO}. + * + *

See also {@link #perKey}/{@link PerKey Combine.PerKey} and + * {@link #groupedValues}/{@link GroupedValues Combine.GroupedValues}, + * which are useful for combining values associated with each key in + * a {@code PCollection} of {@code KV}s. + * + * @param type of input values + * @param type of output values + */ + public static class Globally + extends PTransform, PCollection> { + + private final CombineFn fn; + + private Globally(CombineFn fn) { + this.fn = fn; + } + + @Override + public PCollection apply(PCollection input) { + PCollection output = input + .apply(WithKeys.of((Void) null)) + .setCoder(KvCoder.of(VoidCoder.of(), input.getCoder())) + .apply(Combine.perKey(fn.asKeyedFn())) + .apply(Values.create()); + + if (input.getWindowingFn().isCompatible(new GlobalWindow())) { + return insertDefaultValueIfEmpty(output); + } else { + return output; + } + } + + private PCollection insertDefaultValueIfEmpty(PCollection maybeEmpty) { + final PCollectionView, ?> maybeEmptyView = maybeEmpty.apply( + View.asIterable()); + return maybeEmpty.getPipeline() + .apply(Create.of((Void) null)).setCoder(VoidCoder.of()) + .apply(ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) { + Iterator combined = c.sideInput(maybeEmptyView).iterator(); + if (combined.hasNext()) { + c.output(combined.next()); + } else { + c.output(fn.apply(Collections.emptyList())); + } + } + }).withSideInputs(maybeEmptyView)) + .setCoder(maybeEmpty.getCoder()); + } + + @Override + protected String getKindString() { + return "Combine.Globally"; + } + } + + /** + * Converts a {@link SerializableFunction} from {@code Iterable}s + * to {@code V}s into a simple {@link CombineFn} over {@code V}s. + * + *

Used in the implementation of convenience methods like + * {@link #globally(SerializableFunction)}, + * {@link #perKey(SerializableFunction)}, and + * {@link #groupedValues(SerializableFunction)}. + */ + static class SimpleCombineFn extends CombineFn, V> { + /** + * Returns a {@code CombineFn} that uses the given + * {@code SerializableFunction} to combine values. + */ + public static SimpleCombineFn of( + SerializableFunction, V> combiner) { + return new SimpleCombineFn<>(combiner); + } + + /** + * The number of values to accumulate before invoking the combiner + * function to combine them. + */ + private static final int BUFFER_SIZE = 20; + + /** The combiner function. */ + private final SerializableFunction, V> combiner; + + private SimpleCombineFn(SerializableFunction, V> combiner) { + this.combiner = combiner; + } + + @Override + public List createAccumulator() { + return new ArrayList<>(); + } + + @Override + public void addInput(List accumulator, V input) { + accumulator.add(input); + if (accumulator.size() > BUFFER_SIZE) { + V combined = combiner.apply(accumulator); + accumulator.clear(); + accumulator.add(combined); + } + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + List singleton = new ArrayList<>(); + singleton.add(combiner.apply(Iterables.concat(accumulators))); + return singleton; + } + + @Override + public V extractOutput(List accumulator) { + return combiner.apply(accumulator); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code PerKey} takes a + * {@code PCollection>}, groups it by key, applies a + * combining function to the {@code VI} values associated with each + * key to produce a combined {@code VO} value, and returns a + * {@code PCollection>} representing a map from each + * distinct key of the input {@code PCollection} to the corresponding + * combined value. {@code VI} and {@code VO} are often the same. + * + *

This is a concise shorthand for an application of + * {@link GroupByKey} followed by an application of + * {@link GroupedValues Combine.GroupedValues}. See those + * operations for more details on how keys are compared for equality + * and on the default {@code Coder} for the output. + * + *

Example of use: + *

 {@code
+   * PCollection> salesRecords = ...;
+   * PCollection> totalSalesPerPerson =
+   *     salesRecords.apply(Combine.perKey(
+   *         new Sum.SumDoubleFn()));
+   * } 
+ * + *

Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * as the input. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * @param the type of the values of the output {@code PCollection} + */ + public static class PerKey + extends PTransform>, PCollection>> { + + private final transient KeyedCombineFn fn; + + private PerKey( + KeyedCombineFn fn) { + this.fn = fn; + } + + @Override + public PCollection> apply(PCollection> input) { + return input + .apply(GroupByKey.create()) + .apply(Combine.groupedValues(fn)); + } + + @Override + protected String getKindString() { + return "Combine.PerKey"; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code GroupedValues} takes a + * {@code PCollection>>}, such as the result of + * {@link GroupByKey}, applies a specified + * {@link KeyedCombineFn KeyedCombineFn} + * to each of the input {@code KV>} elements to + * produce a combined output {@code KV} element, and returns a + * {@code PCollection>} containing all the combined output + * elements. It is common for {@code VI == VO}, but not required. + * Common combining functions include sums, mins, maxes, and averages + * of numbers, conjunctions and disjunctions of booleans, statistical + * aggregations, etc. + * + *

Example of use: + *

 {@code
+   * PCollection> pc = ...;
+   * PCollection>> groupedByKey = pc.apply(
+   *     new GroupByKey());
+   * PCollection> sumByKey = groupedByKey.apply(
+   *     Combine.groupedValues(
+   *         new Sum.SumIntegerFn()));
+   * } 
+ * + *

See also {@link #perKey}/{@link PerKey Combine.PerKey} + * which captures the common pattern of "combining by key" in a + * single easy-to-use {@code PTransform}. + * + *

Combining for different keys can happen in parallel. Moreover, + * combining of the {@code Iterable} values associated a single + * key can happen in parallel, with different subsets of the values + * being combined separately, and their intermediate results combined + * further, in an arbitrary tree reduction pattern, until a single + * result value is produced for each key. + * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection>} is that of the keys of the input + * {@code PCollection>}, and the {@code Coder} of the values + * of the output {@code PCollection>} is inferred from the + * concrete type of the {@code KeyedCombineFn}'s output + * type {@code VO}. + * + *

Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

See also {@link #globally}/{@link Globally Combine.Globally}, + * which combines all the values in a {@code PCollection} into a + * single value in a {@code PCollection}. + * + * @param type of input and output keys + * @param type of input values + * @param type of output values + */ + public static class GroupedValues + extends PTransform + >>, + PCollection>> { + + private final KeyedCombineFn fn; + + private GroupedValues(KeyedCombineFn fn) { + this.fn = fn; + } + + /** + * Returns the KeyedCombineFn used by this Combine operation. + */ + public KeyedCombineFn getFn() { + return fn; + } + + @Override + public PCollection> apply( + PCollection>> input) { + Coder> outputCoder = getDefaultOutputCoder(); + return input.apply(ParDo.of( + new DoFn>, KV>() { + @Override + public void processElement(ProcessContext c) { + K key = c.element().getKey(); + c.output(KV.of(key, fn.apply(key, c.element().getValue()))); + } + })).setCoder(outputCoder); + } + + private KvCoder getKvCoder() { + Coder>> inputCoder = + getInput().getCoder(); + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "Combine.GroupedValues requires its input to use KvCoder"); + } + @SuppressWarnings({"unchecked", "rawtypes"}) + KvCoder> kvCoder = (KvCoder) inputCoder; + Coder keyCoder = kvCoder.getKeyCoder(); + Coder> kvValueCoder = kvCoder.getValueCoder(); + if (!(kvValueCoder instanceof IterableCoder)) { + throw new IllegalStateException( + "Combine.GroupedValues requires its input values to use " + + "IterableCoder"); + } + IterableCoder inputValuesCoder = (IterableCoder) kvValueCoder; + Coder inputValueCoder = inputValuesCoder.getElemCoder(); + return KvCoder.of(keyCoder, inputValueCoder); + } + + @SuppressWarnings("unchecked") + public Coder getAccumulatorCoder() { + KvCoder kvCoder = getKvCoder(); + return ((KeyedCombineFn) fn).getAccumulatorCoder( + getCoderRegistry(), kvCoder.getKeyCoder(), kvCoder.getValueCoder()); + } + + @Override + public Coder> getDefaultOutputCoder() { + KvCoder kvCoder = getKvCoder(); + @SuppressWarnings("unchecked") + Coder outputValueCoder = ((KeyedCombineFn) fn) + .getDefaultOutputCoder( + getCoderRegistry(), kvCoder.getKeyCoder(), kvCoder.getValueCoder()); + return KvCoder.of(kvCoder.getKeyCoder(), outputValueCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java new file mode 100644 index 000000000000..1303b0a98634 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code Count} takes a {@code PCollection} and returns a + * {@code PCollection>} representing a map from each + * distinct element of the input {@code PCollection} to the number of times + * that element occurs in the input. Each of the keys in the output + * {@code PCollection} is unique. + * + *

Two values of type {@code T} are compared for equality not by + * regular Java {@link Object#equals}, but instead by first encoding + * each of the elements using the {@code PCollection}'s {@code Coder}, and then + * comparing the encoded bytes. This admits efficient parallel + * evaluation. + * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as the {@code Coder} of the + * elements of the input {@code PCollection}. + * + *

Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * as the input. + * + *

Example of use: + *

 {@code
+ * PCollection words = ...;
+ * PCollection> wordCounts =
+ *     words.apply(Count.create());
+ * } 
+ */ +public class Count { + + /** + * Returns a {@link Globally Count.Globally} {@link PTransform} + * that counts the number of elements in its input {@link PCollection}. + * + *

See {@link Globally Count.Globally} for more details. + */ + public static Globally globally() { + return new Globally<>(); + } + + /** + * Returns a {@link PerElement Count.PerElement} {@link PTransform} + * that counts the number of occurrences of each element in its + * input {@link PCollection}. + * + *

See {@link PerElement Count.PerElement} for more details. + */ + public static PerElement perElement() { + return new PerElement<>(); + } + + /////////////////////////////////////// + + /** + * {@code Count.Globally} takes a {@code PCollection} and returns a + * {@code PCollection} containing a single element which is the total + * number of elements in the {@code PCollection}. + * + *

Example of use: + *

 {@code
+   * PCollection words = ...;
+   * PCollection wordCount =
+   *     words.apply(Count.globally());
+   * } 
+ * + * @param the type of the elements of the input {@code PCollection} + */ + public static class Globally + extends PTransform, PCollection> { + + public Globally() { } + + @Override + public PCollection apply(PCollection input) { + return + input + .apply(ParDo.named("Init") + .of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(1L); + } + })) + .apply(Sum.longsGlobally()); + } + } + + /** + * {@code Count.PerElement} takes a {@code PCollection} and returns a + * {@code PCollection>} representing a map from each + * distinct element of the input {@code PCollection} to the number of times + * that element occurs in the input. Each of the keys in the output + * {@code PCollection} is unique. + * + *

This transform compares two values of type {@code T} by first + * encoding each element using the input {@code PCollection}'s + * {@code Coder}, then comparing the encoded bytes. Because of this, + * the input coder must be deterministic. (See + * {@link com.google.cloud.dataflow.sdk.coders.Coder#isDeterministic()} for more detail). + * Performing the comparison in this manner admits efficient parallel evaluation. + * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as the {@code Coder} of the + * elements of the input {@code PCollection}. + * + *

Example of use: + *

 {@code
+   * PCollection words = ...;
+   * PCollection> wordCounts =
+   *     words.apply(Count.perElement());
+   * } 
+ * + * @param the type of the elements of the input {@code PCollection}, and + * the type of the keys of the output {@code PCollection} + */ + public static class PerElement + extends PTransform, PCollection>> { + + public PerElement() { } + + @Override + public PCollection> apply(PCollection input) { + return + input + .apply(ParDo.named("Init") + .of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), 1L)); + } + })) + .apply(Sum.longsPerKey()); + } + + @Override + public String getKindString() { + return "Count.PerElement"; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Create.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Create.java new file mode 100644 index 000000000000..93747ea6462f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Create.java @@ -0,0 +1,314 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.common.reflect.TypeToken; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * {@code Create} takes a collection of elements of type {@code T} + * known when the pipeline is constructed and returns a + * {@code PCollection} containing the elements. + * + *

Example of use: + *

 {@code
+ * Pipeline p = ...;
+ *
+ * PCollection pc = p.apply(Create.of(3, 4, 5)).setCoder(BigEndianIntegerCoder.of());
+ *
+ * Map map = ...;
+ * PCollection> pt =
+ *     p.apply(Create.of(map))
+ *      .setCoder(KvCoder.of(StringUtf8Coder.of(),
+ *                           BigEndianIntegerCoder.of()));
+ * } 
+ * + *

Note that {@link PCollection#setCoder} must be called + * explicitly to set the encoding of the resulting + * {@code PCollection}, since {@code Create} does not infer the + * encoding. + * + *

A good use for {@code Create} is when a {@code PCollection} + * needs to be created without dependencies on files or other external + * entities. This is especially useful during testing. + * + *

Caveat: {@code Create} only supports small in-memory datasets, + * particularly when submitting jobs to the Google Cloud Dataflow + * service. + * + *

{@code Create} can automatically determine the {@code Coder} to use + * if all elements are the same type, and a default exists for that type. + * See {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry} for details + * on how defaults are determined. + * + * @param the type of the elements of the resulting {@code PCollection} + */ +public class Create extends PTransform> { + + /** + * Returns a new {@code Create} root transform that produces a + * {@link PCollection} containing the specified elements. + * + *

The argument should not be modified after this is called. + * + *

The elements will have a timestamp of negative infinity, see + * {@link Create#timestamped} for a way of creating a {@code PCollection} + * with timestamped elements. + * + *

The result of applying this transform should have its + * {@link Coder} specified explicitly, via a call to + * {@link PCollection#setCoder}. + */ + public static Create of(Iterable elems) { + return new Create<>(elems); + } + + /** + * Returns a new {@code Create} root transform that produces a + * {@link PCollection} containing the specified elements. + * + *

The elements will have a timestamp of negative infinity, see + * {@link Create#timestamped} for a way of creating a {@code PCollection} + * with timestamped elements. + * + *

The argument should not be modified after this is called. + * + *

The result of applying this transform should have its + * {@link Coder} specified explicitly, via a call to + * {@link PCollection#setCoder}. + */ + public static Create of(T... elems) { + return of(Arrays.asList(elems)); + } + + /** + * Returns a new {@code Create} root transform that produces a + * {@link PCollection} of {@link KV}s corresponding to the keys and + * values of the specified {@code Map}. + * + *

The elements will have a timestamp of negative infinity, see + * {@link Create#timestamped} for a way of creating a {@code PCollection} + * with timestamped elements. + * + *

The result of applying this transform should have its + * {@link Coder} specified explicitly, via a call to + * {@link PCollection#setCoder}. + */ + public static Create> of(Map elems) { + List> kvs = new ArrayList<>(elems.size()); + for (Map.Entry entry : elems.entrySet()) { + kvs.add(KV.of(entry.getKey(), entry.getValue())); + } + return of(kvs); + } + + /** + * Returns a new root transform that produces a {@link PCollection} containing + * the specified elements with the specified timestamps. + * + *

The argument should not be modified after this is called. + */ + public static CreateTimestamped timestamped(Iterable> elems) { + return new CreateTimestamped<>(elems); + } + + /** + * Returns a new root transform that produces a {@link PCollection} containing + * the specified elements with the specified timestamps. + * + *

The argument should not be modified after this is called. + */ + public static CreateTimestamped timestamped(TimestampedValue... elems) { + return new CreateTimestamped(Arrays.asList(elems)); + } + + /** + * Returns a new root transform that produces a {@link PCollection} containing + * the specified elements with the specified timestamps. + * + *

The arguments should not be modified after this is called. + * + * @throws IllegalArgumentException if there are a different number of values + * and timestamps + */ + public static CreateTimestamped timestamped( + Iterable values, Iterable timestamps) { + List> elems = new ArrayList<>(); + Iterator valueIter = values.iterator(); + Iterator timestampIter = timestamps.iterator(); + while (valueIter.hasNext() && timestampIter.hasNext()) { + elems.add(TimestampedValue.of(valueIter.next(), new Instant(timestampIter.next()))); + } + Preconditions.checkArgument( + !valueIter.hasNext() && !timestampIter.hasNext(), + "Expect sizes of values and timestamps are same."); + return new CreateTimestamped<>(elems); + } + + @Override + public PCollection apply(PInput input) { + return PCollection.createPrimitiveOutputInternal(new GlobalWindow()); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** The elements of the resulting PCollection. */ + private final Iterable elems; + + /** + * Constructs a {@code Create} transform that produces a + * {@link PCollection} containing the specified elements. + * + *

The argument should not be modified after this is called. + */ + private Create(Iterable elems) { + this.elems = elems; + } + + public Iterable getElements() { + return elems; + } + + @Override + protected Coder getDefaultOutputCoder() { + // First try to deduce a coder using the types of the elements. + Class elementType = null; + for (T elem : elems) { + Class type = elem.getClass(); + if (elementType == null) { + elementType = type; + } else if (!elementType.equals(type)) { + // Elements are not the same type, require a user-specified coder. + elementType = null; + break; + } + } + if (elementType == null) { + return super.getDefaultOutputCoder(); + } + if (elementType.getTypeParameters().length == 0) { + Coder candidate = getCoderRegistry().getDefaultCoder(TypeToken.of(elementType)); + if (candidate != null) { + return candidate; + } + } + + // If that fails, try to deduce a coder using the elements themselves + Coder coder = null; + for (T elem : elems) { + Coder c = getCoderRegistry().getDefaultCoder(elem); + if (coder == null) { + coder = c; + } else if (!Objects.equals(c, coder)) { + coder = null; + break; + } + } + if (coder != null) { + return coder; + } + + return super.getDefaultOutputCoder(); + } + + /** + * A {@code PTransform} that creates a {@code PCollection} whose elements have + * associated timestamps. + */ + private static class CreateTimestamped extends PTransform> { + /** The timestamped elements of the resulting PCollection. */ + private final Iterable> elems; + + private CreateTimestamped(Iterable> elems) { + this.elems = elems; + } + + @Override + public PCollection apply(PBegin input) { + PCollection> intermediate = input.apply(Create.of(elems)); + if (!elems.iterator().hasNext()) { + // There aren't any elements, so we can provide a fake coder instance. + // If we don't set a Coder here, users of CreateTimestamped have + // no way to set the coder of the intermediate PCollection. + intermediate.setCoder((Coder) TimestampedValue.TimestampedValueCoder.of(VoidCoder.of())); + } + + return intermediate.apply(ParDo.of(new ConvertTimestamps())); + } + + private static class ConvertTimestamps extends DoFn, T> { + @Override + public void processElement(ProcessContext c) { + c.outputWithTimestamp(c.element().getValue(), c.element().getTimestamp()); + } + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Create.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Create transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + Create transform, + DirectPipelineRunner.EvaluationContext context) { + // Convert the Iterable of elems into a List of elems. + List listElems; + if (transform.elems instanceof Collection) { + Collection collectionElems = (Collection) transform.elems; + listElems = new ArrayList<>(collectionElems.size()); + } else { + listElems = new ArrayList<>(); + } + for (T elem : transform.elems) { + listElems.add( + context.ensureElementEncodable(transform.getOutput(), elem)); + } + context.setPCollection(transform.getOutput(), listElems); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java new file mode 100644 index 000000000000..3c61ab38557d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java @@ -0,0 +1,330 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.CodedTupleTag; +import com.google.cloud.dataflow.sdk.values.CodedTupleTagMap; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.reflect.TypeToken; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Collection; +import java.util.List; + +/** + * The argument to {@link ParDo} providing the code to use to process + * elements of the input + * {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * + *

See {@link ParDo} for more explanation, examples of use, and + * discussion of constraints on {@code DoFn}s, including their + * serializability, lack of access to global shared mutable state, + * requirements for failure tolerance, and benefits of optimization. + * + *

{@code DoFn}s can be tested in the context of a particular + * {@code Pipeline} by running that {@code Pipeline} on sample input + * and then checking its output. Unit testing of a {@code DoFn}, + * separately from any {@code ParDo} transform or {@code Pipeline}, + * can be done via the {@link DoFnTester} harness. + * + * @param the type of the (main) input elements + * @param the type of the (main) output elements + */ +public abstract class DoFn implements Serializable { + + /** Information accessible to all methods in this {@code DoFn}. */ + public abstract class Context { + + /** + * Returns the {@code PipelineOptions} specified with the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * invoking this {@code DoFn}. The {@code PipelineOptions} will + * be the default running via {@link DoFnTester}. + */ + public abstract PipelineOptions getPipelineOptions(); + + /** + * Returns the value of the side input. + * + * @throws IllegalArgumentException if this is not a side input + * @see ParDo#withSideInput + */ + public abstract T sideInput(PCollectionView view); + + /** + * Adds the given element to the main output {@code PCollection}. + * + *

If invoked from {@link DoFn#processElement}, the output + * element will have the same timestamp and be in the same windows + * as the input element passed to {@link DoFn#processElement}). + * + *

Is is illegal to invoke this from {@link #startBundle} or + * {@link #finishBundle} unless the input {@code PCollection} is + * windowed by the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow}. + * If this is the case, the output element will have a timestamp + * of negative infinity and be in the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow}. + */ + public abstract void output(O output); + + /** + * Adds the given element to the main output {@code PCollection}, + * with the given timestamp. + * + *

If invoked from {@link DoFn#processElement}), the timestamp + * must not be older than the input element's timestamp minus + * {@link DoFn#getAllowedTimestampSkew}. The output element will + * be in the same windows as the input element. + * + *

Is is illegal to invoke this from {@link #startBundle} or + * {@link #finishBundle} unless the input {@code PCollection} is + * windowed by the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow}. + * If this is the case, the output element's timestamp will be + * the given timestamp and its window will be the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow}. + */ + public abstract void outputWithTimestamp(O output, Instant timestamp); + + /** + * Adds the given element to the side output {@code PCollection} with the + * given tag. + * + *

The caller of {@code ParDo} uses {@link ParDo#withOutputTags} to + * specify the tags of side outputs that it consumes. Non-consumed side + * outputs, e.g., outputs for monitoring purposes only, don't necessarily + * need to be specified. + * + *

The output element will have the same timestamp and be in the same + * windows as the input element passed to {@link DoFn#processElement}). + * + * @throws IllegalArgumentException if the number of outputs exceeds + * the limit of 1,000 outputs per DoFn + * @see ParDo#withOutputTags + */ + public abstract void sideOutput(TupleTag tag, T output); + + // TODO: add sideOutputWithTimestamp[AndWindows] + + /** + * Returns an aggregator with aggregation logic specified by the CombineFn + * argument. The name provided should be unique across aggregators created + * within the containing ParDo transform application. + * + *

All instances of this DoFn in the containing ParDo + * transform application should define aggregators consistently, + * i.e., an aggregator with a given name always specifies the same + * combiner in all DoFn instances in the containing ParDo + * transform application. + * + * @throws IllegalArgumentException if the given CombineFn is not + * supported as aggregator's combiner, or if the given name collides + * with another aggregator or system-provided counter. + */ + public abstract Aggregator createAggregator( + String name, Combine.CombineFn combiner); + + /** + * Returns an aggregator with aggregation logic specified by the + * SerializableFunction argument. The name provided should be unique across + * aggregators created within the containing ParDo transform application. + * + *

All instances of this DoFn in the containing ParDo + * transform application should define aggregators consistently, + * i.e., an aggregator with a given name always specifies the same + * combiner in all DoFn instances in the containing ParDo + * transform application. + * + * @throws IllegalArgumentException if the given SerializableFunction is + * not supported as aggregator's combiner, or if the given name collides + * with another aggregator or system-provided counter. + */ + public abstract Aggregator createAggregator( + String name, SerializableFunction, AO> combiner); + } + + /** + * Information accessible when running {@link DoFn#processElement}. + */ + public abstract class ProcessContext extends Context { + + /** + * Returns the input element to be processed. + */ + public abstract I element(); + + /** + * Returns this {@code DoFn}'s state associated with the input + * element's key. This state can be used by the {@code DoFn} to + * store whatever information it likes with that key. Unlike + * {@code DoFn} instance variables, this state is persistent and + * can be arbitrarily large; it is more expensive than instance + * variable state, however. It is particularly intended for + * streaming computations. + * + *

Requires that this {@code DoFn} implements + * {@link RequiresKeyedState}. + * + *

Each {@link ParDo} invocation with this {@code DoFn} as an + * argument will maintain its own {@code KeyedState} maps, one per + * key. + * + * @throws UnsupportedOperationException if this {@link DoFn} does + * not implement {@link RequiresKeyedState} + */ + public abstract KeyedState keyedState(); + + /** + * Returns the timestamp of the input element. + * + *

See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + */ + public abstract Instant timestamp(); + + /** + * Returns the set of windows to which the input element has been assigned. + * + *

See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + */ + public abstract Collection windows(); + } + + /** + * Returns the allowed timestamp skew duration, which is the maximum + * duration that timestamps can be shifted backward in + * {@link DoFn.Context#outputWithTimestamp}. + * + * The default value is {@code Duration.ZERO}, in which case + * timestamps can only be shifted forward to future. For infinite + * skew, return {@code Duration.millis(Long.MAX_VALUE)}. + */ + public Duration getAllowedTimestampSkew() { + return Duration.ZERO; + } + + /** + * Interface for signaling that a {@link DoFn} needs to maintain + * per-key state, accessed via + * {@link DoFn.ProcessContext#keyedState}. + * + *

This functionality is experimental and likely to change. + */ + public interface RequiresKeyedState {} + + /** + * Interface for interacting with keyed state. + * + *

This functionality is experimental and likely to change. + */ + public interface KeyedState { + /** + * Updates this {@code KeyedState} in place so that the given tag + * maps to the given value. + * + * @throws IOException if encoding the given value fails + */ + public void store(CodedTupleTag tag, T value) throws IOException; + + /** + * Returns the value associated with the given tag in this + * {@code KeyedState}, or {@code null} if the tag has no asssociated + * value. + * + *

See {@link #lookup(List)} to look up multiple tags at + * once. It is significantly more efficient to look up multiple + * tags all at once rather than one at a time. + * + * @throws IOException if decoding the requested value fails + */ + public T lookup(CodedTupleTag tag) throws IOException; + + /** + * Returns a map from the given tags to the values associated with + * those tags in this {@code KeyedState}. A tag will map to null if + * the tag had no associated value. + * + *

See {@link #lookup(CodedTupleTag)} to look up a single + * tag. + * + * @throws CoderException if decoding any of the requested values fails + */ + public CodedTupleTagMap lookup(List> tags) throws IOException; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Prepares this {@code DoFn} instance for processing a batch of elements. + * + *

By default, does nothing. + */ + public void startBundle(Context c) throws Exception { + } + + /** + * Processes an input element. + */ + public abstract void processElement(ProcessContext c) throws Exception; + + /** + * Finishes processing this batch of elements. This {@code DoFn} + * instance will be thrown away after this operation returns. + * + *

By default, does nothing. + */ + public void finishBundle(Context c) throws Exception { + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns a {@link TypeToken} capturing what is known statically + * about the input type of this {@code DoFn} instance's most-derived + * class. + * + *

See {@link #getOutputTypeToken} for more discussion. + */ + TypeToken getInputTypeToken() { + return new TypeToken(getClass()) {}; + } + + /** + * Returns a {@link TypeToken} capturing what is known statically + * about the output type of this {@code DoFn} instance's + * most-derived class. + * + *

In the normal case of a concrete {@code DoFn} subclass with + * no generic type parameters of its own (including anonymous inner + * classes), this will be a complete non-generic type, which is good + * for choosing a default output {@code Coder} for the output + * {@code PCollection}. + */ + TypeToken getOutputTypeToken() { + return new TypeToken(getClass()) {}; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnTester.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnTester.java new file mode 100644 index 000000000000..3e23b5ed0450 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnTester.java @@ -0,0 +1,357 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.base.Function; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A harness for unit-testing a {@link DoFn}. + * + *

For example: + * + *

 {@code
+ * DoFn fn = ...;
+ *
+ * DoFnTester fnTester = DoFnTester.of(fn);
+ *
+ * // Set arguments shared across all batches:
+ * fnTester.setSideInputs(...);      // If fn takes side inputs.
+ * fnTester.setSideOutputTags(...);  // If fn writes to side outputs.
+ *
+ * // Process a batch containing a single input element:
+ * Input testInput = ...;
+ * List testOutputs = fnTester.processBatch(testInput);
+ * Assert.assertThat(testOutputs,
+ *                   JUnitMatchers.hasItems(...));
+ *
+ * // Process a bigger batch:
+ * Assert.assertThat(fnTester.processBatch(i1, i2, ...),
+ *                   JUnitMatchers.hasItems(...));
+ * } 
+ * + * @param the type of the {@code DoFn}'s (main) input elements + * @param the type of the {@code DoFn}'s (main) output elements + */ +public class DoFnTester { + /** + * Returns a {@code DoFnTester} supporting unit-testing of the given + * {@link DoFn}. + */ + @SuppressWarnings("unchecked") + public static DoFnTester of(DoFn fn) { + return new DoFnTester(fn); + } + + /** + * Registers the tuple of values of the side input {@link PCollectionView}s to + * pass to the {@link DoFn} under test. + * + *

If needed, first creates a fresh instance of the {@link DoFn} + * under test. + * + *

If this isn't called, {@code DoFnTester} assumes the + * {@link DoFn} takes no side inputs. + */ + public void setSideInputs(Map, Iterable>> sideInputs) { + this.sideInputs = sideInputs; + resetState(); + } + + /** + * Registers the values of a side input {@link PCollectionView} to + * pass to the {@link DoFn} under test. + * + *

If needed, first creates a fresh instance of the {@code DoFn} + * under test. + * + *

If this isn't called, {@code DoFnTester} assumes the + * {@code DoFn} takes no side inputs. + */ + public void setSideInput(PCollectionView sideInput, Iterable> value) { + sideInputs.put(sideInput, value); + } + + /** + * Registers the values for a side input {@link PCollectionView} to + * pass to the {@link DoFn} under test. All values are placed + * in the global window. + */ + public void setSideInputInGlobalWindow( + PCollectionView sideInput, + Iterable value) { + sideInputs.put( + sideInput, + Iterables.transform(value, new Function>() { + @Override + public WindowedValue apply(Object input) { + return WindowedValue.valueInGlobalWindow(input); + } + })); + } + + + /** + * Registers the list of {@code TupleTag}s that can be used by the + * {@code DoFn} under test to output to side output + * {@code PCollection}s. + * + *

If needed, first creates a fresh instance of the DoFn under test. + * + *

If this isn't called, {@code DoFnTester} assumes the + * {@code DoFn} doesn't emit to any side outputs. + */ + public void setSideOutputTags(TupleTagList sideOutputTags) { + this.sideOutputTags = sideOutputTags.getAll(); + resetState(); + } + + /** + * A convenience operation that first calls {@link #startBundle}, + * then calls {@link #processElement} on each of the arguments, then + * calls {@link #finishBundle}, then returns the result of + * {@link #takeOutputElements}. + */ + public List processBatch(I... inputElements) { + startBundle(); + for (I inputElement : inputElements) { + processElement(inputElement); + } + finishBundle(); + return takeOutputElements(); + } + + /** + * Calls {@link DoFn#startBundle} on the {@code DoFn} under test. + * + *

If needed, first creates a fresh instance of the DoFn under test. + */ + public void startBundle() { + resetState(); + initializeState(); + fnRunner.startBundle(); + state = State.STARTED; + } + + /** + * Calls {@link DoFn#processElement} on the {@code DoFn} under test, in a + * context where {@link DoFn.ProcessContext#element} returns the + * given element. + * + *

Will call {@link #startBundle} automatically, if it hasn't + * already been called. + * + * @throws IllegalStateException if the {@code DoFn} under test has already + * been finished + */ + public void processElement(I element) { + if (state == State.FINISHED) { + throw new IllegalStateException("finishBundle() has already been called"); + } + if (state == State.UNSTARTED) { + startBundle(); + } + fnRunner.processElement(WindowedValue.valueInGlobalWindow(element)); + } + + /** + * Calls {@link DoFn#finishBundle} of the {@code DoFn} under test. + * + *

Will call {@link #startBundle} automatically, if it hasn't + * already been called. + * + * @throws IllegalStateException if the {@code DoFn} under test has already + * been finished + */ + public void finishBundle() { + if (state == State.FINISHED) { + throw new IllegalStateException("finishBundle() has already been called"); + } + if (state == State.UNSTARTED) { + startBundle(); + } + fnRunner.finishBundle(); + state = State.FINISHED; + } + + /** + * Returns the elements output so far to the main output. Does not + * clear them, so subsequent calls will continue to include these + * elements. + * + * @see #takeOutputElements + * @see #clearOutputElements + * + * TODO: provide accessors that take and return {@code WindowedValue}s + * in order to test timestamp- and window-sensitive DoFns. + */ + public List peekOutputElements() { + // TODO: Should we return an unmodifiable list? + return Lists.transform(fnRunner.getReceiver(mainOutputTag), + new Function() { + @Override + public O apply(Object input) { + return ((WindowedValue) input).getValue(); + } + }); + + } + + /** + * Clears the record of the elements output so far to the main output. + * + * @see #peekOutputElements + */ + public void clearOutputElements() { + peekOutputElements().clear(); + } + + /** + * Returns the elements output so far to the main output. + * Clears the list so these elements don't appear in future calls. + * + * @see #peekOutputElements + */ + public List takeOutputElements() { + List resultElems = new ArrayList<>(peekOutputElements()); + clearOutputElements(); + return resultElems; + } + + /** + * Returns the elements output so far to the side output with the + * given tag. Does not clear them, so subsequent calls will + * continue to include these elements. + * + * @see #takeSideOutputElements + * @see #clearSideOutputElements + */ + public List peekSideOutputElements(TupleTag tag) { + // TODO: Should we return an unmodifiable list? + return Lists.transform(fnRunner.getReceiver(tag), + new Function() { + @Override + public T apply(Object input) { + return ((WindowedValue) input).getValue(); + }}); + } + + /** + * Clears the record of the elements output so far to the side + * output with the given tag. + * + * @see #peekSideOutputElements + */ + public void clearSideOutputElements(TupleTag tag) { + peekSideOutputElements(tag).clear(); + } + + /** + * Returns the elements output so far to the side output with the given tag. + * Clears the list so these elements don't appear in future calls. + * + * @see #peekSideOutputElements + */ + public List takeSideOutputElements(TupleTag tag) { + List resultElems = new ArrayList<>(peekSideOutputElements(tag)); + clearSideOutputElements(tag); + return resultElems; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** The possible states of processing a DoFn. */ + enum State { UNSTARTED, STARTED, FINISHED } + + final PipelineOptions options = PipelineOptionsFactory.create(); + + /** The original DoFn under test. */ + final DoFn origFn; + + /** The side input values to provide to the DoFn under test. */ + private Map, Iterable>> sideInputs = + new HashMap<>(); + + /** The output tags used by the DoFn under test. */ + TupleTag mainOutputTag = new TupleTag<>(); + List> sideOutputTags = new ArrayList<>(); + + /** The original DoFn under test, if started. */ + DoFn fn; + + /** The DoFnRunner if processing is in progress. */ + DoFnRunner fnRunner; + + /** Counters for user-defined Aggregators if processing is in progress. */ + CounterSet counterSet; + // TODO: expose counterSet through a getter method, once we have + // a convenient public API for it. + + /** The state of processing of the DoFn under test. */ + State state; + + DoFnTester(DoFn origFn) { + this.origFn = origFn; + resetState(); + } + + void resetState() { + fn = null; + fnRunner = null; + counterSet = null; + state = State.UNSTARTED; + } + + @SuppressWarnings("unchecked") + void initializeState() { + fn = (DoFn) + SerializableUtils.deserializeFromByteArray( + SerializableUtils.serializeToByteArray(origFn), + origFn.toString()); + counterSet = new CounterSet(); + PTuple runnerSideInputs = PTuple.empty(); + for (Map.Entry, Iterable>> entry + : sideInputs.entrySet()) { + runnerSideInputs = runnerSideInputs.and(entry.getKey().getTagInternal(), entry.getValue()); + } + fnRunner = DoFnRunner.createWithListOutputs( + options, + fn, + runnerSideInputs, + mainOutputTag, + sideOutputTags, + (new BatchModeExecutionContext()).createStepContext("stepName"), + counterSet.getAddCounterMutator()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/First.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/First.java new file mode 100644 index 000000000000..9e4f3b099d48 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/First.java @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +/** + * {@code First} takes a {@code PCollection} and a limit, and + * produces a new {@code PCollection} containing up to limit + * elements of the input {@code PCollection}. + * + *

If the input and output {@code PCollection}s are ordered, then + * {@code First} will select the first elements, otherwise it will + * select any elements. + * + *

If limit is less than or equal to the size of the input + * {@code PCollection}, then all the input's elements will be selected. + * + *

All of the elements of the output {@code PCollection} should fit into + * main memory of a single worker machine. This operation does not + * run in parallel. + * + *

Example of use: + *

 {@code
+ * PCollection input = ...;
+ * PCollection output = input.apply(First.of(100));
+ * } 
+ * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ +public class First extends PTransform, PCollection> { + /** + * Returns a {@code First} {@code PTransform}. + * + * @param the type of the elements of the input and output + * {@code PCollection}s + * @param limit the numer of elements to take from the input + */ + public static First of(long limit) { + return new First<>(limit); + } + + private final long limit; + + /** + * Constructs a {@code First} PTransform that, when applied, + * produces a new PCollection containing up to {@code limit} + * elements of its input {@code PCollection}. + */ + private First(long limit) { + this.limit = limit; + if (limit < 0) { + throw new IllegalArgumentException( + "limit argument to First should be non-negative"); + } + } + + private static class CopyFirstDoFn extends DoFn { + long limit; + final PCollectionView, ?> iterableView; + + public CopyFirstDoFn(long limit, PCollectionView, ?> iterableView) { + this.limit = limit; + this.iterableView = iterableView; + } + + @Override + public void processElement(ProcessContext c) { + for (T i : c.sideInput(iterableView)) { + if (limit-- <= 0) { + break; + } + c.output(i); + } + } + } + + @Override + public PCollection apply(PCollection in) { + PCollectionView, ?> iterableView = in.apply(View.asIterable()); + return + in.getPipeline() + .apply(Create.of((Void) null)).setCoder(VoidCoder.of()) + .apply(ParDo + .withSideInputs(iterableView) + .of(new CopyFirstDoFn<>(limit, iterableView))) + .setCoder(in.getCoder()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Flatten.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Flatten.java new file mode 100644 index 000000000000..14b2169b97bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Flatten.java @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@code Flatten} takes multiple {@code PCollection}s bundled + * into a {@code PCollectionList} and returns a single + * {@code PCollection} containing all the elements in all the input + * {@code PCollection}s. The name "Flatten" suggests taking a list of + * lists and flattening them into a single list. + * + *

Example of use: + *

 {@code
+ * PCollection pc1 = ...;
+ * PCollection pc2 = ...;
+ * PCollection pc3 = ...;
+ * PCollectionList pcs = PCollectionList.of(pc1).and(pc2).and(pc3);
+ * PCollection merged = pcs.apply(Flatten..create());
+ * } 
+ * + *

By default, the {@code Coder} of the output {@code PCollection} + * is the same as the {@code Coder} of the first {@code PCollection} + * in the input {@code PCollectionList} (if the + * {@code PCollectionList} is non-empty). + * + */ +public class Flatten { + + /** + * Returns a {@link PTransform} that flattens a {@link PCollectionList} + * into a {@link PCollection} containing all the elements of all + * the {@link PCollection}s in its input. + * + *

If any of the inputs to {@code Flatten} require window merging, + * all inputs must have equal {@link WindowingFn}s. + * The output elements of {@code Flatten} are in the same windows and + * have the same timestamps as their corresponding input elements. The output + * {@code PCollection} will have the same + * {@link WindowingFn} as all of the inputs. + * + * @param the type of the elements in the input and output + * {@code PCollection}s. + */ + public static FlattenPCollectionList pCollections() { + return new FlattenPCollectionList<>(); + } + + @Deprecated + public static FlattenPCollectionList create() { + return pCollections(); + } + + /** + * Returns a {@code PTransform} that takes a {@code PCollection>} + * and returns a {@code PCollection} containing all the elements from + * all the {@code Iterable}s. + * + *

Example of use: + *

 {@code
+   * PCollection> pcOfIterables = ...;
+   * PCollection pc = pcOfIterables.apply(Flatten.iterables());
+   * } 
+ * + *

By default, the output {@code PCollection} encodes its elements + * using the same {@code Coder} that the input uses for + * the elements in its {@code Iterable}. + * + * @param the type of the elements of the input {@code Iterable} and + * the output {@code PCollection} + */ + public static FlattenIterables iterables() { + return new FlattenIterables<>(); + } + + /** + * A {@link PTransform} that flattens a {@link PCollectionList} + * into a {@link PCollection} containing all the elements of all + * the {@link PCollection}s in its input. + * + * @param the type of the elements in the input and output + * {@code PCollection}s. + */ + public static class FlattenPCollectionList + extends PTransform, PCollection> { + + private FlattenPCollectionList() { } + + @Override + public PCollection apply(PCollectionList inputs) { + WindowingFn windowingFn; + if (!getInput().getAll().isEmpty()) { + windowingFn = getInput().get(0).getWindowingFn(); + for (PCollection input : getInput().getAll()) { + if (!windowingFn.isCompatible(input.getWindowingFn())) { + throw new IllegalStateException( + "Inputs to Flatten had incompatible window windowingFns: " + + windowingFn + ", " + input.getWindowingFn()); + } + } + } else { + windowingFn = new GlobalWindow(); + } + + return PCollection.createPrimitiveOutputInternal(windowingFn); + } + + @Override + protected Coder getDefaultOutputCoder() { + List> inputs = getInput().getAll(); + if (inputs.isEmpty()) { + // Cannot infer a Coder from an empty list of input PCollections. + return null; + } + // Use the Coder of the first input. + return inputs.get(0).getCoder(); + } + + } + + /** + * {@code FlattenIterables} takes a {@code PCollection>} and returns a + * {@code PCollection} that contains all the elements from each iterable. + * Implements {@link #fromIterable}. + * + * @param the type of the elements of the input {@code Iterable}s and + * the output {@code PCollection} + */ + public static class FlattenIterables + extends PTransform>, PCollection> { + + @Override + public PCollection apply(PCollection> in) { + Coder> inCoder = in.getCoder(); + if (!(inCoder instanceof IterableCoder)) { + throw new IllegalArgumentException( + "expecting the input Coder to be an IterableCoder"); + } + IterableCoder iterableCoder = (IterableCoder) inCoder; + Coder elemCoder = iterableCoder.getElemCoder(); + + return in.apply(ParDo.of( + new DoFn, T>() { + @Override + public void processElement(ProcessContext c) { + for (T i : c.element()) { + c.output(i); + } + } + })) + .setCoder(elemCoder); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + FlattenPCollectionList.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + FlattenPCollectionList transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + FlattenPCollectionList transform, + DirectPipelineRunner.EvaluationContext context) { + List> outputElems = new ArrayList<>(); + PCollectionList inputs = transform.getInput(); + + for (PCollection input : inputs.getAll()) { + outputElems.addAll(context.getPCollectionValuesWithMetadata(input)); + } + + context.setPCollectionValuesWithMetadata(transform.getOutput(), outputElems); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/GroupByKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/GroupByKey.java new file mode 100644 index 000000000000..d7a4de64e50d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/GroupByKey.java @@ -0,0 +1,517 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.ValueWithMetadata; +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.InvalidWindowingFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.NonMergingWindowingFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * {@code GroupByKey} takes a {@code PCollection>}, + * groups the values by key and windows, and returns a + * {@code PCollection>>} representing a map from + * each distinct key and window of the input {@code PCollection} to an + * {@code Iterable} over all the values associated with that key in + * the input. Each key in the output {@code PCollection} is unique within + * each window. + * + *

{@code GroupByKey} is analogous to converting a multi-map into + * a uni-map, and related to {@code GROUP BY} in SQL. It corresponds + * to the "shuffle" step between the Mapper and the Reducer in the + * MapReduce framework. + * + *

Two keys of type {@code K} are compared for equality + * not by regular Java {@link Object#equals}, but instead by + * first encoding each of the keys using the {@code Coder} of the + * keys of the input {@code PCollection}, and then comparing the + * encoded bytes. This admits efficient parallel evaluation. Note that + * this requires that the {@code Coder} of the keys be deterministic (see + * {@link Coder#isDeterministic()}). If the key {@code Coder} is not + * deterministic, an exception is thrown at runtime. + * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input, + * and the {@code Coder} of the elements of the {@code Iterable} + * values of the output {@code PCollection} is the same as the + * {@code Coder} of the values of the input. + * + *

Example of use: + *

 {@code
+ * PCollection> urlDocPairs = ...;
+ * PCollection>> urlToDocs =
+ *     urlDocPairs.apply(GroupByKey.create());
+ * PCollection results =
+ *     urlToDocs.apply(ParDo.of(new DoFn>, R>() {
+ *       public void processElement(ProcessContext c) {
+ *         String url = c.element().getKey();
+ *         Iterable docsWithThatUrl = c.element().getValue();
+ *         ... process all docs having that url ...
+ *       }}));
+ * } 
+ * + *

{@code GroupByKey} is a key primitive in data-parallel + * processing, since it is the main way to efficiently bring + * associated data together into one location. It is also a key + * determiner of the performance of a data-parallel pipeline. + * + *

See {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey} + * for a way to group multiple input PCollections by a common key at once. + * + *

See {@link Combine.PerKey} for a common pattern of + * {@code GroupByKey} followed by {@link Combine.GroupedValues}. + * + *

When grouping, windows that can be merged according to the {@link WindowingFn} + * of the input {@code PCollection} will be merged together, and a group + * corresponding to the new, merged window will be emitted. + * The timestamp for each group is the upper bound of its window, e.g., the most + * recent timestamp that can be assigned into the window, and the group will be + * in the window that it corresponds to. The output {@code PCollection} will + * have the same {@link WindowingFn} as the input. + * + *

If the {@link WindowingFn} of the input requires merging, it is not + * valid to apply another {@code GroupByKey} without first applying a new + * {@link WindowingFn}. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * and the elements of the {@code Iterable}s in the output + * {@code PCollection} + */ +public class GroupByKey + extends PTransform>, + PCollection>>> { + /** + * Returns a {@code GroupByKey} {@code PTransform}. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * and the elements of the {@code Iterable}s in the output + * {@code PCollection} + */ + public static GroupByKey create() { + return new GroupByKey<>(); + } + + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public PCollection>> apply(PCollection> input) { + return applyHelper(input, false, false); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Helper transform that makes timestamps and window assignments + * explicit in the value part of each key/value pair. + */ + public static class ReifyTimestampsAndWindows + extends PTransform>, + PCollection>>> { + @Override + public PCollection>> apply( + PCollection> input) { + Coder> inputCoder = getInput().getCoder(); + KvCoder inputKvCoder = (KvCoder) inputCoder; + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder inputValueCoder = inputKvCoder.getValueCoder(); + Coder> outputValueCoder = FullWindowedValueCoder.of( + inputValueCoder, getInput().getWindowingFn().windowCoder()); + Coder>> outputKvCoder = + KvCoder.of(keyCoder, outputValueCoder); + return input.apply(ParDo.of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + KV kv = c.element(); + K key = kv.getKey(); + V value = kv.getValue(); + c.output(KV.of( + key, + WindowedValue.of(value, c.timestamp(), c.windows()))); + }})) + .setCoder(outputKvCoder); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Helper transform that sorts the values associated with each key + * by timestamp. + */ + public static class SortValuesByTimestamp + extends PTransform>>>, + PCollection>>>> { + @Override + public PCollection>>> apply( + PCollection>>> input) { + return input.apply(ParDo.of( + new DoFn>>, + KV>>>() { + @Override + public void processElement(ProcessContext c) { + KV>> kvs = c.element(); + K key = kvs.getKey(); + Iterable> unsortedValues = kvs.getValue(); + List> sortedValues = new ArrayList<>(); + for (WindowedValue value : unsortedValues) { + sortedValues.add(value); + } + Collections.sort(sortedValues, + new Comparator>() { + @Override + public int compare(WindowedValue e1, WindowedValue e2) { + return e1.getTimestamp().compareTo(e2.getTimestamp()); + } + }); + c.output(KV.>>of(key, sortedValues)); + }})) + .setCoder(getInput().getCoder()); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Helper transform that takes a collection of timestamp-ordered + * values associated with each key, groups the values by window, + * combines windows as needed, and for each window in each key, + * outputs a collection of key/value-list pairs implicitly assigned + * to the window and with the timestamp derived from that window. + */ + public static class GroupAlsoByWindow + extends PTransform>>>, + PCollection>>> { + private final WindowingFn windowingFn; + + public GroupAlsoByWindow(WindowingFn windowingFn) { + this.windowingFn = windowingFn; + } + + @Override + public PCollection>> apply( + PCollection>>> input) { + Coder>>> inputCoder = + getInput().getCoder(); + KvCoder>> inputKvCoder = + (KvCoder>>) inputCoder; + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder>> inputValueCoder = + inputKvCoder.getValueCoder(); + IterableCoder> inputIterableValueCoder = + (IterableCoder>) inputValueCoder; + Coder> inputIterableElementCoder = + inputIterableValueCoder.getElemCoder(); + WindowedValueCoder inputIterableWindowedValueCoder = + (WindowedValueCoder) inputIterableElementCoder; + Coder inputIterableElementValueCoder = + inputIterableWindowedValueCoder.getValueCoder(); + Coder> outputValueCoder = + IterableCoder.of(inputIterableElementValueCoder); + Coder>> outputKvCoder = + KvCoder.of(keyCoder, outputValueCoder); + + return input.apply(ParDo.of( + new GroupAlsoByWindowsDoFn( + windowingFn, inputIterableElementValueCoder))) + .setCoder(outputKvCoder); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Primitive helper transform that groups by key only, ignoring any + * window assignments. + */ + public static class GroupByKeyOnly + extends PTransform>, + PCollection>>> { + // TODO: Define and implement sorting by value. + boolean sortsValues = false; + + public GroupByKeyOnly() { } + + @Override + public PCollection>> apply(PCollection> input) { + WindowingFn windowingFn = getInput().getWindowingFn(); + if (!(windowingFn instanceof NonMergingWindowingFn)) { + // Prevent merging windows again, without explicit user + // involvement, e.g., by Window.into() or Window.remerge(). + windowingFn = new InvalidWindowingFn( + "WindowingFn has already been consumed by previous GroupByKey", + windowingFn); + } + return PCollection.>>createPrimitiveOutputInternal( + windowingFn); + } + + @Override + public void finishSpecifying() { + // Verify that the input Coder> is a KvCoder, and that + // the key coder is deterministic. + Coder keyCoder = getKeyCoder(); + if (!keyCoder.isDeterministic()) { + throw new IllegalStateException( + "the key Coder must be deterministic for grouping"); + } + if (getOutput().isOrdered()) { + throw new IllegalStateException( + "the result of a GroupByKey cannot be specified to be ordered"); + } + super.finishSpecifying(); + } + + /** + * Returns the {@code Coder} of the input to this transform, which + * should be a {@code KvCoder}. + */ + KvCoder getInputKvCoder() { + Coder> inputCoder = getInput().getCoder(); + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "GroupByKey requires its input to use KvCoder"); + } + return (KvCoder) inputCoder; + } + + /** + * Returns the {@code Coder} of the keys of the input to this + * transform, which is also used as the {@code Coder} of the keys of + * the output of this transform. + */ + Coder getKeyCoder() { + return getInputKvCoder().getKeyCoder(); + } + + /** + * Returns the {@code Coder} of the values of the input to this transform. + */ + Coder getInputValueCoder() { + return getInputKvCoder().getValueCoder(); + } + + /** + * Returns the {@code Coder} of the {@code Iterable} values of the + * output of this transform. + */ + Coder> getOutputValueCoder() { + return IterableCoder.of(getInputValueCoder()); + } + + /** + * Returns the {@code Coder} of the output of this transform. + */ + KvCoder> getOutputKvCoder() { + return KvCoder.of(getKeyCoder(), getOutputValueCoder()); + } + + @Override + protected Coder>> getDefaultOutputCoder() { + return getOutputKvCoder(); + } + + /** + * Returns whether this GBK sorts values. + */ + boolean sortsValues() { + return sortsValues; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + GroupByKeyOnly.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + GroupByKeyOnly transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + GroupByKeyOnly transform, + DirectPipelineRunner.EvaluationContext context) { + PCollection> input = transform.getInput(); + + List>> inputElems = + context.getPCollectionValuesWithMetadata(input); + + Coder keyCoder = transform.getKeyCoder(); + + Map, List> groupingMap = new HashMap<>(); + + for (ValueWithMetadata> elem : inputElems) { + K key = elem.getValue().getKey(); + V value = elem.getValue().getValue(); + Instant timestamp = elem.getTimestamp(); + byte[] encodedKey; + try { + encodedKey = encodeToByteArray(keyCoder, key); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. + throw new IllegalArgumentException( + "unable to encode key " + key + " of input to " + transform + + " using " + keyCoder, + exn); + } + GroupingKey groupingKey = new GroupingKey<>(key, encodedKey); + List values = groupingMap.get(groupingKey); + if (values == null) { + values = new ArrayList(); + groupingMap.put(groupingKey, values); + } + values.add(value); + } + + List>>> outputElems = + new ArrayList<>(); + for (Map.Entry, List> entry : groupingMap.entrySet()) { + GroupingKey groupingKey = entry.getKey(); + K key = groupingKey.getKey(); + List values = entry.getValue(); + values = context.randomizeIfUnordered( + transform.sortsValues(), values, true /* inPlaceAllowed */); + outputElems.add(ValueWithMetadata + .of(WindowedValue.valueInEmptyWindows(KV.>of(key, values))) + .withKey(key)); + } + + context.setPCollectionValuesWithMetadata(transform.getOutput(), + outputElems); + } + + public PCollection>> applyHelper( + PCollection> input, boolean isStreaming, boolean runnerSortsByTimestamp) { + Coder> inputCoder = getInput().getCoder(); + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "GroupByKey requires its input to use KvCoder"); + } + // This operation groups by the combination of key and window, + // merging windows as needed, using the windows assigned to the + // key/value input elements and the window merge operation of the + // windowing function associated with the input PCollection. + WindowingFn windowingFn = getInput().getWindowingFn(); + if (windowingFn instanceof InvalidWindowingFn) { + String cause = ((InvalidWindowingFn) windowingFn).getCause(); + throw new IllegalStateException( + "GroupByKey must have a valid Window merge function. " + + "Invalid because: " + cause); + } + if (windowingFn.isCompatible(new GlobalWindow())) { + // The input PCollection is using the degenerate default + // windowing function, which uses a single global window for all + // elements. We can implement this using a more-primitive + // non-window-aware GBK transform. + return input.apply(new GroupByKeyOnly()); + + } else if (isStreaming) { + // If using the streaming runner, the service will do the insertion of + // the GroupAlsoByWindow step. + // TODO: Remove this case once the Dataflow Runner handles GBK directly + return input.apply(new GroupByKeyOnly()); + + } else { + // By default, implement GroupByKey[AndWindow] via a series of lower-level + // operations. + PCollection>>> gbkOutput = input + // Make each input element's timestamp and assigned windows + // explicit, in the value part. + .apply(new ReifyTimestampsAndWindows()) + + // Group by just the key. + .apply(new GroupByKeyOnly>()); + + if (!runnerSortsByTimestamp) { + // Sort each key's values by timestamp. GroupAlsoByWindow requires + // its input to be sorted by timestamp. + gbkOutput = gbkOutput.apply(new SortValuesByTimestamp()); + } + + return gbkOutput + // Group each key's values by window, merging windows as needed. + .apply(new GroupAlsoByWindow(windowingFn)); + } + } + + private static class GroupingKey { + private K key; + private byte[] encodedKey; + + public GroupingKey(K key, byte[] encodedKey) { + this.key = key; + this.encodedKey = encodedKey; + } + + public K getKey() { return key; } + + @Override + public boolean equals(Object o) { + if (o instanceof GroupingKey) { + GroupingKey that = (GroupingKey) o; + return Arrays.equals(this.encodedKey, that.encodedKey); + } else { + return false; + } + } + + @Override + public int hashCode() { return Arrays.hashCode(encodedKey); } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Keys.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Keys.java new file mode 100644 index 000000000000..08a801b15ec2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Keys.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code Keys} takes a {@code PCollection} of {@code KV}s and + * returns a {@code PCollection} of the keys. + * + *

Example of use: + *

 {@code
+ * PCollection> wordCounts = ...;
+ * PCollection words = wordCounts.apply(Keys.create());
+ * } 
+ * + *

Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

See also {@link Values}. + * + * @param the type of the keys in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ +public class Keys extends PTransform>, + PCollection> { + /** + * Returns a {@code Keys} {@code PTransform}. + * + * @param the type of the keys in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ + public static Keys create() { + return new Keys<>(); + } + + private Keys() { } + + @Override + public PCollection apply(PCollection> in) { + return + in.apply(ParDo.named("Keys") + .of(new DoFn, K>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey()); + } + })); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/KvSwap.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/KvSwap.java new file mode 100644 index 000000000000..ee73ae4087f5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/KvSwap.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code KvSwap} takes a {@code PCollection>} and + * returns a {@code PCollection>}, where all the keys and + * values have been swapped. + * + *

Example of use: + *

 {@code
+ * PCollection wordsToCounts = ...;
+ * PCollection countsToWords =
+ *     wordToCounts.apply(KvSwap.create());
+ * } 
+ * + *

Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + * @param the type of the keys in the input {@code PCollection} + * and the values in the output {@code PCollection} + * @param the type of the values in the input {@code PCollection} + * and the keys in the output {@code PCollection} + */ +public class KvSwap extends PTransform>, + PCollection>> { + /** + * Returns a {@code KvSwap} {@code PTransform}. + * + * @param the type of the keys in the input {@code PCollection} + * and the values in the output {@code PCollection} + * @param the type of the values in the input {@code PCollection} + * and the keys in the output {@code PCollection} + */ + public static KvSwap create() { + return new KvSwap<>(); + } + + private KvSwap() { } + + @Override + public PCollection> apply(PCollection> in) { + return + in.apply(ParDo.named("KvSwap") + .of(new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + c.output(KV.of(e.getValue(), e.getKey())); + } + })); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Max.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Max.java new file mode 100644 index 000000000000..fce9a328f1c8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Max.java @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +/** + * {@code PTransform}s for computing the maximum of the elements in a + * {@code PCollection}, or the maximum of the values associated with + * each key in a {@code PCollection} of {@code KV}s. + * + *

Example 1: get the maximum of a {@code PCollection} of {@code Double}s. + *

 {@code
+ * PCollection input = ...;
+ * PCollection max = input.apply(Max.doublesGlobally());
+ * } 
+ * + *

Example 2: calculate the maximum of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

 {@code
+ * PCollection> input = ...;
+ * PCollection> maxPerKey = input
+ *     .apply(Max.integersPerKey());
+ * } 
+ */ +public class Max { + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the maximum of the + * input {@code PCollection}'s elements, or + * {@code Integer.MIN_VALUE} if there are no elements. + */ + public static Combine.Globally integersGlobally() { + Combine.Globally combine = Combine + .globally(new MaxIntegerFn()); + combine.setName("Max"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the maximum of the values associated with + * that key in the input {@code PCollection}. + * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey integersPerKey() { + Combine.PerKey combine = Combine + .perKey(new MaxIntegerFn()); + combine.setName("Max.PerKey"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the maximum of the + * input {@code PCollection}'s elements, or + * {@code Long.MIN_VALUE} if there are no elements. + */ + public static Combine.Globally longsGlobally() { + Combine.Globally combine = Combine.globally(new MaxLongFn()); + combine.setName("Max"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the maximum of the values associated with + * that key in the input {@code PCollection}. + * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey longsPerKey() { + Combine.PerKey combine = Combine + .perKey(new MaxLongFn()); + combine.setName("Max.PerKey"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the maximum of the + * input {@code PCollection}'s elements, or + * {@code Double.MIN_VALUE} if there are no elements. + */ + public static Combine.Globally doublesGlobally() { + Combine.Globally combine = Combine + .globally(new MaxDoubleFn()); + combine.setName("Max"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the maximum of the values associated with + * that key in the input {@code PCollection}. + * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey doublesPerKey() { + Combine.PerKey combine = Combine + .perKey(new MaxDoubleFn()); + combine.setName("Max.PerKey"); + return combine; + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code SerializableFunction} that computes the maximum of an + * {@code Iterable} of numbers of type {@code N}, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + * + * @param the type of the {@code Number}s being compared + */ + public static class MaxFn> + implements SerializableFunction, N> { + + /** The smallest value of type N. */ + private final N initialValue; + + /** + * Constructs a combining function that computes the maximum over + * a collection of values of type {@code N}, given the smallest + * value of type {@code N}, which is the identity value for the + * maximum operation over {@code N}s. + */ + public MaxFn(N initialValue) { + this.initialValue = initialValue; + } + + @Override + public N apply(Iterable input) { + N max = initialValue; + for (N value : input) { + if (value.compareTo(max) > 0) { + max = value; + } + } + return max; + } + } + + /** + * A {@code SerializableFunction} that computes the maximum of an + * {@code Iterable} of {@code Integer}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MaxIntegerFn extends MaxFn { + public MaxIntegerFn() { super(Integer.MIN_VALUE); } + } + + /** + * A {@code SerializableFunction} that computes the maximum of an + * {@code Iterable} of {@code Long}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MaxLongFn extends MaxFn { + public MaxLongFn() { super(Long.MIN_VALUE); } + } + + /** + * A {@code SerializableFunction} that computes the maximum of an + * {@code Iterable} of {@code Double}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MaxDoubleFn extends MaxFn { + public MaxDoubleFn() { super(Double.MIN_VALUE); } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Mean.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Mean.java new file mode 100644 index 000000000000..34fbb1fc2908 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Mean.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; + +/** + * {@code PTransform}s for computing the arithmetic mean + * (a.k.a. average) of the elements in a {@code PCollection}, or the + * mean of the values associated with each key in a + * {@code PCollection} of {@code KV}s. + * + *

Example 1: get the mean of a {@code PCollection} of {@code Long}s. + *

 {@code
+ * PCollection input = ...;
+ * PCollection mean = input.apply(Mean.globally());
+ * } 
+ * + *

Example 2: calculate the mean of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

 {@code
+ * PCollection> input = ...;
+ * PCollection> meanPerKey =
+ *     input.apply(Mean.perKey());
+ * } 
+ */ +public class Mean { + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the mean of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + * + * @param the type of the {@code Number}s being combined + */ + public static Combine.Globally globally() { + Combine.Globally combine = Combine.globally(new MeanFn<>()); + combine.setName("Mean"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the mean of the values associated with + * that key in the input {@code PCollection}. + * + * See {@link Combine.PerKey} for how this affects timestamps and bucketing. + * + * @param the type of the keys + * @param the type of the {@code Number}s being combined + */ + public static Combine.PerKey perKey() { + Combine.PerKey combine = Combine.perKey(new MeanFn<>()); + combine.setName("Mean.PerKey"); + return combine; + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code Combine.CombineFn} that computes the arithmetic mean + * (a.k.a. average) of an {@code Iterable} of numbers of type + * {@code N}, useful as an argument to {@link Combine#globally} or + * {@link Combine#perKey}. + * + *

Returns {@code 0} if combining zero elements. + * + * @param the type of the {@code Number}s being combined + */ + public static class MeanFn extends + Combine.AccumulatingCombineFn.CountSum, Double> { + + /** + * Constructs a combining function that computes the mean over + * a collection of values of type {@code N}. + */ + public MeanFn() {} + + /** + * Accumulator helper class for MeanFn. + */ + class CountSum + extends Combine.AccumulatingCombineFn.Accumulator { + + long count = 0; + double sum = 0.0; + + @Override + public void addInput(N element) { + count++; + sum += element.doubleValue(); + } + + @Override + public void mergeAccumulator(CountSum accumulator) { + count += accumulator.count; + sum += accumulator.sum; + } + + @Override + public Double extractOutput() { + return count == 0 ? 0.0 : sum / count; + } + } + + @Override + public CountSum createAccumulator() { + return new CountSum(); + } + + @SuppressWarnings("unchecked") + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + // The casts are needed because CountSum.class is a + // Class, but we need a + // Class.CountSum>. + return SerializableCoder.of((Class) (Class) CountSum.class); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Min.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Min.java new file mode 100644 index 000000000000..337a05116097 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Min.java @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +/** + * {@code PTransform}s for computing the minimum of the elements in a + * {@code PCollection}, or the minimum of the values associated with + * each key in a {@code PCollection} of {@code KV}s. + * + *

Example 1: get the minimum of a {@code PCollection} of {@code Double}s. + *

 {@code
+ * PCollection input = ...;
+ * PCollection min = input.apply(Min.doublesGlobally());
+ * } 
+ * + *

Example 2: calculate the minimum of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

 {@code
+ * PCollection> input = ...;
+ * PCollection> minPerKey = input
+ *     .apply(Min.integersPerKey());
+ * } 
+ */ +public class Min { + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is a single value that is + * the minimum of the input {@code PCollection}'s elements, or + * {@code Integer.MAX_VALUE} if there are no elements. + */ + public static Combine.Globally integersGlobally() { + Combine.Globally combine = Combine + .globally(new MinIntegerFn()); + combine.setName("Min"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the minimum of the values associated with + * that key in the input {@code PCollection}. + * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey integersPerKey() { + Combine.PerKey combine = Combine + .perKey(new MinIntegerFn()); + combine.setName("Min.PerKey"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the minimum of the + * input {@code PCollection}'s elements, or + * {@code Long.MAX_VALUE} if there are no elements. + */ + public static Combine.Globally longsGlobally() { + Combine.Globally combine = Combine.globally(new MinLongFn()); + combine.setName("Min"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the minimum of the values associated with + * that key in the input {@code PCollection}. + * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey longsPerKey() { + Combine.PerKey combine = Combine + .perKey(new MinLongFn()); + combine.setName("Min.PerKey"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the minimum of the + * input {@code PCollection}'s elements, or + * {@code Double.MAX_VALUE} if there are no elements. + */ + public static Combine.Globally doublesGlobally() { + Combine.Globally combine = Combine + .globally(new MinDoubleFn()); + combine.setName("Min"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the minimum of the values associated with + * that key in the input {@code PCollection}. + * + *

See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey doublesPerKey() { + Combine.PerKey combine = Combine + .perKey(new MinDoubleFn()); + combine.setName("Min.PerKey"); + return combine; + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code SerializableFunction} that computes the minimum of an + * {@code Iterable} of numbers of type {@code N}, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + * + * @param the type of the {@code Number}s being compared + */ + public static class MinFn> + implements SerializableFunction, N> { + + /** The largest value of type N. */ + private final N initialValue; + + /** + * Constructs a combining function that computes the minimum over + * a collection of values of type {@code N}, given the largest + * value of type {@code N}, which is the identity value for the + * minimum operation over {@code N}s. + */ + public MinFn(N initialValue) { + this.initialValue = initialValue; + } + + @Override + public N apply(Iterable input) { + N min = initialValue; + for (N value : input) { + if (value.compareTo(min) < 0) { + min = value; + } + } + return min; + } + } + + /** + * A {@code SerializableFunction} that computes the minimum of an + * {@code Iterable} of {@code Integer}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MinIntegerFn extends MinFn { + public MinIntegerFn() { super(Integer.MAX_VALUE); } + } + + /** + * A {@code SerializableFunction} that computes the minimum of an + * {@code Iterable} of {@code Long}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MinLongFn extends MinFn { + public MinLongFn() { super(Long.MAX_VALUE); } + } + + /** + * A {@code SerializableFunction} that computes the minimum of an + * {@code Iterable} of {@code Double}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MinDoubleFn extends MinFn { + public MinDoubleFn() { super(Double.MAX_VALUE); } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java new file mode 100644 index 000000000000..5906d7212dba --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java @@ -0,0 +1,400 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.TypedPValue; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; + +/** + * A {@code PTransform} is an operation that takes an + * {@code Input} (some subtype of {@link PInput}) and produces an + * {@code Output} (some subtype of {@link POutput}). + * + *

Common PTransforms include root PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read}, + * {@link Create}, processing and + * conversion operations like {@link ParDo}, + * {@link GroupByKey}, + * {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey}, + * {@link Combine}, and {@link Count}, and outputting + * PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Write}. Users also + * define their own application-specific composite PTransforms. + * + *

Each {@code PTransform} has a single + * {@code Input} type and a single {@code Output} type. Many + * PTransforms conceptually transform one input value to one output + * value, and in this case {@code Input} and {@code Output} are + * typically instances of + * {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * A root + * PTransform conceptually has no input; in this case, conventionally + * a {@link com.google.cloud.dataflow.sdk.values.PBegin} object + * produced by calling {@link Pipeline#begin} is used as the input. + * An outputting PTransform conceptually has no output; in this case, + * conventionally {@link com.google.cloud.dataflow.sdk.values.PDone} + * is used as its output type. Some PTransforms conceptually have + * multiple inputs and/or outputs; in these cases special "bundling" + * classes like + * {@link com.google.cloud.dataflow.sdk.values.PCollectionList}, + * {@link com.google.cloud.dataflow.sdk.values.PCollectionTuple} + * are used + * to combine multiple values into a single bundle for passing into or + * returning from the PTransform. + * + *

A {@code PTransform} is invoked by calling + * {@code apply()} on its {@code Input}, returning its {@code Output}. + * Calls can be chained to concisely create linear pipeline segments. + * For example: + * + *

 {@code
+ * PCollection pc1 = ...;
+ * PCollection pc2 =
+ *     pc1.apply(ParDo.of(new MyDoFn>()))
+ *        .apply(GroupByKey.create())
+ *        .apply(Combine.perKey(new MyKeyedCombineFn()))
+ *        .apply(ParDo.of(new MyDoFn2,T2>()));
+ * } 
+ * + *

PTransform operations have unique names, which are used by the + * system when explaining what's going on during optimization and + * execution. Each PTransform gets a system-provided default name, + * but it's a good practice to specify an explicit name, where + * possible, using the {@code named()} method offered by some + * PTransforms such as {@link ParDo}. For example: + * + *

 {@code
+ * ...
+ * .apply(ParDo.named("Step1").of(new MyDoFn3()))
+ * ...
+ * } 
+ * + *

Each PCollection output produced by a PTransform, + * either directly or within a "bundling" class, automatically gets + * its own name derived from the name of its producing PTransform. An + * output's name can be changed by invoking + * {@link com.google.cloud.dataflow.sdk.values.PValue#setName}. + * + *

Each PCollection output produced by a PTransform + * also records a {@link com.google.cloud.dataflow.sdk.coders.Coder} + * that specifies how the elements of that PCollection + * are to be encoded as a byte string, if necessary. The + * PTransform may provide a default Coder for any of its outputs, for + * instance by deriving it from the PTransform input's Coder. If the + * PTransform does not specify the Coder for an output PCollection, + * the system will attempt to infer a Coder for it, based on + * what's known at run-time about the Java type of the output's + * elements. The enclosing {@link Pipeline}'s + * {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry} + * (accessible via {@link Pipeline#getCoderRegistry}) defines the + * mapping from Java types to the default Coder to use, for a standard + * set of Java types; users can extend this mapping for additional + * types, via + * {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry#registerCoder}. + * If this inference process fails, either because the Java type was + * not known at run-time (e.g., due to Java's "erasure" of generic + * types) or there was no default Coder registered, then the Coder + * should be specified manually by calling + * {@link com.google.cloud.dataflow.sdk.values.TypedPValue#setCoder} + * on the output PCollection. The Coder of every output + * PCollection must be determined one way or another + * before that output is used as an input to another PTransform, or + * before the enclosing Pipeline is run. + * + *

A small number of PTransforms are implemented natively by the + * Google Cloud Dataflow SDK; such PTransforms simply return an + * output value as their apply implementation. + * The majority of PTransforms are + * implemented as composites of other PTransforms. Such a PTransform + * subclass typically just implements {@link #apply}, computing its + * Output value from its Input value. User programs are encouraged to + * use this mechanism to modularize their own code. Such composite + * abstractions get their own name, and navigating through the + * composition hierarchy of PTransforms is supported by the monitoring + * interface. Examples of composite PTransforms can be found in this + * directory and in examples. From the caller's point of view, there + * is no distinction between a PTransform implemented natively and one + * implemented in terms of other PTransforms; both kinds of PTransform + * are invoked in the same way, using {@code apply()}. + * + *

Note on Serialization

+ * + * {@code PTransform} doesn't actually support serialization, despite + * implementing {@code Serializable}. + * + *

{@code PTransform} is marked {@code Serializable} solely + * because it is common for an anonymous {@code DoFn}, + * instance to be created within an + * {@code apply()} method of a composite {@code PTransform}. + * + *

Each of those {@code *Fn}s is {@code Serializable}, but + * unfortunately its instance state will contain a reference to the + * enclosing {@code PTransform} instance, and so attempt to serialize + * the {@code PTransform} instance, even though the {@code *Fn} + * instance never references anything about the enclosing + * {@code PTransform}. + * + *

Composite transforms, which are defined in terms of other transforms, + * should return the output of one of the composed transforms. Non-composite + * transforms, which do not apply any transforms internally, should return + * a new unbound output and register evaluators (via backend-specific + * registration methods). + * + *

The default implementation throws an exception. A derived class must + * either implement apply, or else each runner must supply a custom + * implementation via + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner#apply}. + */ + public Output apply(Input input) { + throw new IllegalArgumentException( + "Runner " + getPipeline().getRunner() + + " has not registered an implementation for the required primitive operation " + + this); + } + + /** + * Sets the base name of this {@code PTransform}. + */ + public void setName(String name) { + this.name = name; + } + + /** + * Sets the base name of this {@code PTransform} and returns itself. + * + *

This is a shortcut for calling {@link #setName}, which allows method + * chaining. + */ + public PTransform withName(String name) { + setName(name); + return this; + } + + /** + * Returns the transform name. + * + *

This name is provided by the transform creator and is not required to be unique. + */ + public String getName() { + return name != null ? name : getDefaultName(); + } + + /** + * Returns the owning {@link Pipeline} of this {@code PTransform}. + * + * @throws IllegalStateException if the owning {@code Pipeline} hasn't been + * set yet + */ + @Deprecated + public Pipeline getPipeline() { + if (pipeline == null) { + throw new IllegalStateException("owning pipeline not set"); + } + return pipeline; + } + + /** + * Returns the input of this transform. + * + * @throws IllegalStateException if this PTransform hasn't been applied yet + * @deprecated Use pipeline.getInput(transform) + */ + @Deprecated + public Input getInput() { + @SuppressWarnings("unchecked") + Input input = (Input) getPipeline().getInput(this); + return input; + } + + /** + * Returns the output of this transform. + * + * @throws IllegalStateException if this PTransform hasn't been applied yet + * #deprecated use pipeline.getOutput(transform) + */ + @Deprecated + public Output getOutput() { + @SuppressWarnings("unchecked") + Output output = (Output) getPipeline().getOutput(this); + return output; + } + + /** + * Returns the {@link CoderRegistry}, useful for inferring + * {@link com.google.cloud.dataflow.sdk.coders.Coder}s. + * + * @throws IllegalStateException if the owning {@link Pipeline} hasn't been + * set yet + * @deprecated use pipeline.getCoderRegistry() + */ + @Deprecated + protected CoderRegistry getCoderRegistry() { + return getPipeline().getCoderRegistry(); + } + + + ///////////////////////////////////////////////////////////////////////////// + + // See the note about about PTransform's fake Serializability, to + // understand why all of its instance state is transient. + + /** + * The base name of this {@code PTransform}, e.g., from + * {@link ParDo#named(String)}, or from defaults, or {@code null} if not + * yet assigned. + */ + protected transient String name; + + /** + * The {@link Pipeline} that owns this {@code PTransform}, or {@code null} + * if not yet set. + */ + private transient Pipeline pipeline; + + protected PTransform() { + this.name = null; + } + + protected PTransform(String name) { + this.name = name; + } + + /** + * Associates this {@code PTransform} with the given {@code Pipeline}. + * + *

For internal use only. + * + * @throws IllegalArgumentException if this transform has already + * been associated with a pipeline + */ + @Deprecated + public void setPipeline(Pipeline pipeline) { + if (this.pipeline != null) { + throw new IllegalStateException( + "internal error: transform already initialized"); + } + this.pipeline = pipeline; + } + + @Override + public String toString() { + return getName() + " [" + getKindString() + "]"; + } + + /** + * Returns the name to use by default for this {@code PTransform} + * (not including the names of any enclosing {@code PTransform}s). + * + *

By default, returns {@link #getKindString}. + * + *

The caller is responsible for ensuring that names of applied + * {@code PTransform}s are unique, e.g., by adding a uniquifying + * suffix when needed. + */ + protected String getDefaultName() { + return getKindString(); + } + + /** + * Returns a string describing what kind of {@code PTransform} this is. + * + *

By default, returns the base name of this + * {@code PTransform}'s class. + */ + protected String getKindString() { + return StringUtils.approximateSimpleName(getClass()); + } + + private void writeObject(ObjectOutputStream oos) throws IOException { + // We don't really want to be serializing this object, but we + // often have serializable anonymous DoFns nested within a + // PTransform. + } + + private void readObject(ObjectInputStream oos) + throws IOException, ClassNotFoundException { + // We don't really want to be serializing this object, but we + // often have serializable anonymous DoFns nested within a + // PTransform. + } + + /** + * After building, finalizes this {@code PTransform} to + * make it ready for running. Called automatically when its + * output(s) are finished. + * + *

Not normally called by user code. + */ + public void finishSpecifying() { + getOutput().finishSpecifyingOutput(); + } + + /** + * Returns the default {@code Coder} to use for the output of this + * single-output {@code PTransform}, or {@code null} if + * none can be inferred. + * + *

By default, returns {@code null}. + */ + protected Coder getDefaultOutputCoder() { + return null; + } + + /** + * Returns the default {@code Coder} to use for the given output of + * this single-output {@code PTransform}, or {@code null} + * if none can be inferred. + */ + public Coder getDefaultOutputCoder(TypedPValue output) { + if (output != getOutput()) { + return null; + } else { + @SuppressWarnings("unchecked") + Coder defaultOutputCoder = (Coder) getDefaultOutputCoder(); + return defaultOutputCoder; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java new file mode 100644 index 000000000000..c7d925b2b418 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java @@ -0,0 +1,1054 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * {@code ParDo} is the core element-wise transform in Google Cloud + * Dataflow, invoking a user-specified function (from {@code I} to + * {@code O}) on each of the elements of the input + * {@code PCollection} to produce zero or more output elements, all + * of which are collected into the output {@code PCollection}. + * + *

Elements are processed independently, and possibly in parallel across + * distributed cloud resources. + * + *

The {@code ParDo} processing style is similar to what happens inside + * the "Mapper" or "Reducer" class of a MapReduce-style algorithm. + * + *

{@code DoFn}s

+ * + *

The function to use to process each element is specified by a + * {@link DoFn DoFn}. + * + *

Conceptually, when a {@code ParDo} transform is executed, the + * elements of the input {@code PCollection} are first divided up + * into some number of "batches". These are farmed off to distributed + * worker machines (or run locally, if using the + * {@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner}). + * For each batch of input elements, a fresh instance of the argument + * {@code DoFn} is created on a worker, then the {@code DoFn}'s + * optional {@link DoFn#startBundle} method is called to initialize it, + * then the {@code DoFn}'s required {@link DoFn#processElement} method + * is called on each of the input elements in the batch, then the + * {@code DoFn}'s optional {@link DoFn#finishBundle} method is called + * to complete its work, and finally the {@code DoFn} instance is + * thrown away. Each of the calls to any of the {@code DoFn}'s + * methods can produce zero or more output elements, which are + * collected together into a batch of output elements. All of the + * batches of output elements from all of the {@code DoFn} instances + * are "flattened" together into the output {@code PCollection}. + * + *

For example: + * + *

 {@code
+ * PCollection lines = ...;
+ * PCollection words =
+ *     lines.apply(ParDo.of(new DoFn() {
+ *         public void processElement(ProcessContext c) {
+ *           String line = c.element();
+ *           for (String word : line.split("[^a-zA-Z']+")) {
+ *             c.output(word);
+ *           }
+ *         }}));
+ * PCollection wordLengths =
+ *     words.apply(ParDo.of(new DoFn() {
+ *         public void processElement(ProcessContext c) {
+ *           String word = c.element();
+ *           Integer length = word.length();
+ *           c.output(length);
+ *         }}));
+ * } 
+ * + *

Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

Naming {@code ParDo}s

+ * + *

A {@code ParDo} transform can be given a name using + * {@link #named}. While the system will automatically provide a name + * if none is specified explicitly, it is still a good practice to + * provide an explicit name, since that will probably make monitoring + * output more readable. For example: + * + *

 {@code
+ * PCollection words =
+ *     lines.apply(ParDo.named("ExtractWords")
+ *                      .of(new DoFn() { ... }));
+ * PCollection wordLengths =
+ *     words.apply(ParDo.named("ComputeWordLengths")
+ *                      .of(new DoFn() { ... }));
+ * } 
+ * + *

Side Inputs

+ * + *

While a {@code ParDo} iterates over a single "main input" + * {@code PCollection}, it can take additional "side input" + * {@code PCollectionView}s. These side input + * {@code PCollectionView}s express styles of accessing + * {@code PCollection}s computed by earlier pipeline operations, + * passed in to the {@code ParDo} transform using + * {@link #withSideInputs}, and their contents accessible to each of + * the {@code DoFn} operations via {@link DoFn.Context#sideInput}. + * For example: + * + *

 {@code
+ * PCollection words = ...;
+ * PCollection maxWordLengthCutOff = ...; // Singleton PCollection
+ * final PCollectionView maxWordLengthCutOffView =
+ *     SingletonPCollectionView.of(maxWordLengthCutOff);
+ * PCollection wordsBelowCutOff =
+ *     words.apply(ParDo.withSideInput(maxWordLengthCutOffView)
+ *                      .of(new DoFn() {
+ *         public void processElement(ProcessContext c) {
+ *           String word = c.element();
+ *           int lengthCutOff = c.sideInput(maxWordLengthCutOffView);
+ *           if (word.length() <= lengthCutOff) {
+ *             c.output(word);
+ *           }
+ *         }}));
+ * } 
+ * + *

Side Outputs

+ * + *

Optionally, a {@code ParDo} transform can produce multiple + * output {@code PCollection}s, both a "main output" + * {@code PCollection} plus any number of "side output" + * {@code PCollection}s, each keyed by a distinct {@link TupleTag}, + * and bundled in a {@link PCollectionTuple}. The {@code TupleTag}s + * to be used for the output {@code PCollectionTuple} is specified by + * invoking {@link #withOutputTags}. Unconsumed side outputs does not + * necessarily need to be explicity specified, even if the {@code DoFn} + * generates them. Within the {@code DoFn}, an element is added to the + * main output {@code PCollection} as normal, using + * {@link DoFn.Context#output}, while an element is added to a side output + * {@code PCollection} using {@link DoFn.Context#sideOutput}. For example: + * + *

 {@code
+ * PCollection words = ...;
+ * // Select words whose length is below a cut off,
+ * // plus the lengths of words that are above the cut off.
+ * // Also select words starting with "MARKER".
+ * final int wordLengthCutOff = 10;
+ * // Create tags to use for the main and side outputs.
+ * final TupleTag wordsBelowCutOffTag =
+ *     new TupleTag(){};
+ * final TupleTag wordLengthsAboveCutOffTag =
+ *     new TupleTag(){};
+ * final TupleTag markedWordsTag =
+ *     new TupleTag(){};
+ * PCollectionTuple results =
+ *     words.apply(
+ *         ParDo
+ *         // Specify the main and consumed side output tags of the
+ *         // PCollectionTuple result:
+ *         .withOutputTags(wordsBelowCutOffTag,
+ *                         TupleTagList.of(wordLengthsAboveCutOffTag)
+ *                                     .and(markedWordsTag))
+ *         .of(new DoFn() {
+ *             // Create a tag for the unconsumed side output.
+ *             final TupleTag specialWordsTag =
+ *                 new TupleTag(){};
+ *             public void processElement(ProcessContext c) {
+ *               String word = c.element();
+ *               if (word.length() <= wordLengthCutOff) {
+ *                 // Emit this short word to the main output.
+ *                 c.output(word);
+ *               } else {
+ *                 // Emit this long word's length to a side output.
+ *                 c.sideOutput(wordLengthsAboveCutOffTag, word.length());
+ *               }
+ *               if (word.startsWith("MARKER")) {
+ *                 // Emit this word to a different side output.
+ *                 c.sideOutput(markedWordsTag, word);
+ *               }
+ *               if (word.startsWith("SPECIAL")) {
+ *                 // Emit this word to the unconsumed side output.
+ *                 c.sideOutput(specialWordsTag, word);
+ *               }
+ *             }}));
+ * // Extract the PCollection results, by tag.
+ * PCollection wordsBelowCutOff =
+ *     results.get(wordsBelowCutOffTag);
+ * PCollection wordLengthsAboveCutOff =
+ *     results.get(wordLengthsAboveCutOffTag);
+ * PCollection markedWords =
+ *     results.get(markedWordsTag);
+ * } 
+ * + *

Properties May Be Specified In Any Order

+ * + * Several properties can be specified for a {@code ParDo} + * {@code PTransform}, including name, side inputs, side output tags, + * and {@code DoFn} to invoke. Only the {@code DoFn} is required; the + * name is encouraged but not required, and side inputs and side + * output tags are only specified when they're needed. These + * properties can be specified in any order, as long as they're + * specified before the {@code ParDo} {@code PTransform} is applied. + * + *

The approach used to allow these properties to be specified in + * any order, with some properties omitted, is to have each of the + * property "setter" methods defined as static factory methods on + * {@code ParDo} itself, which return an instance of either + * {@link ParDo.Unbound ParDo.Unbound} or + * {@link ParDo.Bound ParDo.Bound} nested classes, each of which offer + * property setter instance methods to enable setting additional + * properties. {@code ParDo.Bound} is used for {@code ParDo} + * transforms whose {@code DoFn} is specified and whose input and + * output static types have been bound. {@code ParDo.Unbound} is used + * for {@code ParDo} transforms that have not yet had their + * {@code DoFn} specified. Only {@code ParDo.Bound} instances can be + * applied. + * + *

Another benefit of this approach is that it reduces the number + * of type parameters that need to be specified manually. In + * particular, the input and output types of the {@code ParDo} + * {@code PTransform} are inferred automatically from the type + * parameters of the {@code DoFn} argument passed to {@link ParDo#of}. + * + *

Output Coders

+ * + *

By default, the {@code Coder} of the + * elements of the main output {@code PCollection} is inferred from the + * concrete type of the {@code DoFn}'s output type {@code O}. + * + *

By default, the {@code Coder} of the elements of a side output + * {@code PCollection} is inferred from the concrete type of the + * corresponding {@code TupleTag}'s type {@code X}. To be + * successful, the {@code TupleTag} should be created as an instance + * of a trivial anonymous subclass, with {@code {}} suffixed to the + * constructor call. Such uses block Java's generic type parameter + * inference, so the {@code } argument must be provided explicitly. + * For example: + *

 {@code
+ * // A TupleTag to use for a side input can be written concisely:
+ * final TupleTag sideInputTag = new TupleTag<>();
+ * // A TupleTag to use for a side output should be written with "{}",
+ * // and explicit generic parameter type:
+ * final TupleTag sideOutputTag = new TupleTag(){};
+ * } 
+ * This style of {@code TupleTag} instantiation is used in the example of + * multiple side outputs, above. + * + *

Ordered Input and/or Output PCollections

+ * + *

If the input {@code PCollection} is ordered (see + * {@link PCollection#setOrdered}), then each batch of the input + * processed by a {@code DoFn} instance will correspond to a + * consecutive subsequence of elements of the input, and the + * {@link DoFn#processElement} operation will be invoked on each + * element of the batch in order; otherwise, batches will correspond + * to arbitrary subsets of elements of the input, processed in + * arbitrary order. + * + *

Independently, if a main or side output {@code PCollection} is + * ordered, then the order in which elements are output to it will be + * preserved in the output {@code PCollection}; otherwise, the order + * in which elements are output to the {@code PCollection} doesn't + * matter. If the input {@code PCollection} is also ordered, then the + * sequences of elements output from the batches will be concatenated + * together in the same order as the batches appear in the input, + * supporting order-preserving transforms on {@code PCollection}s. + * + *

Serializability of {@code DoFn}s

+ * + *

A {@code DoFn} passed to a {@code ParDo} transform must be + * {@code Serializable}. This allows the {@code DoFn} instance + * created in this "main program" to be sent (in serialized form) to + * remote worker machines and reconstituted for each batch of elements + * of the input {@code PCollection} being processed. A {@code DoFn} + * can have instance variable state, and non-transient instance + * variable state will be serialized in the main program and then + * deserialized on remote worker machines for each batch of elements + * to process. + * + *

To aid in ensuring that {@code DoFn}s are properly + * {@code Serializable}, even local execution using the + * {@link DirectPipelineRunner} will serialize and then deserialize + * {@code DoFn}s before executing them on a batch. + * + *

{@code DoFn}s expressed as anonymous inner classes can be + * convenient, but due to a quirk in Java's rules for serializability, + * non-static inner or nested classes (including anonymous inner + * classes) automatically capture their enclosing class's instance in + * their serialized state. This can lead to including much more than + * intended in the serialized state of a {@code DoFn}, or even things + * that aren't {@code Serializable}. + * + *

There are two ways to avoid unintended serialized state in a + * {@code DoFn}: + * + *

    + * + *
  • Define the {@code DoFn} as a named, static class. + * + *
  • Define the {@code DoFn} as an anonymous inner class inside of + * a static method. + * + *
+ * + * Both these approaches ensure that there is no implicit enclosing + * class instance serialized along with the {@code DoFn} instance. + * + *

Prior to Java 8, any local variables of the enclosing + * method referenced from within an anonymous inner class need to be + * marked as {@code final}. If defining the {@code DoFn} as a named + * static class, such variables would be passed as explicit + * constructor arguments and stored in explicit instance variables. + * + *

There are three main ways to initialize the state of a + * {@code DoFn} instance processing a batch: + * + *

    + * + *
  • Define instance variable state (including implicit instance + * variables holding final variables captured by an anonymous inner + * class), initialized by the {@code DoFn}'s constructor (which is + * implicit for an anonymous inner class). This state will be + * automatically serialized and then deserialized in the {@code DoFn} + * instance created for each batch. This method is good for state + * known when the original {@code DoFn} is created in the main + * program, if it's not overly large. + * + *
  • Compute the state as a singleton {@code PCollection} and pass it + * in as a side input to the {@code DoFn}. This is good if the state + * needs to be computed by the pipeline, or if the state is very large + * and so is best read from file(s) rather than sent as part of the + * {@code DoFn}'s serialized state. + * + *
  • Initialize the state in each {@code DoFn} instance, in + * {@link DoFn#startBundle}. This is good if the initialization + * doesn't depend on any information known only by the main program or + * computed by earlier pipeline operations, but is the same for all + * instances of this {@code DoFn} for all program executions, say + * setting up empty caches or initializing constant data. + * + *
+ * + *

No Global Shared State

+ * + *

{@code ParDo} operations are intended to be able to run in + * parallel across multiple worker machines. This precludes easy + * sharing and updating mutable state across those machines. There is + * no support in the Google Cloud Dataflow system for communicating + * and synchronizing updates to shared state across worker machines, + * so programs should not access any mutable static variable state in + * their {@code DoFn}, without understanding that the Java processes + * for the main program and workers will each have its own independent + * copy of such state, and there won't be any automatic copying of + * that state across Java processes. All information should be + * communicated to {@code DoFn} instances via main and side inputs and + * serialized state, and all output should be communicated from a + * {@code DoFn} instance via main and side outputs, in the absence of + * external communication mechanisms written by user code. + * + *

Fault Tolerance

+ * + *

In a distributed system, things can fail: machines can crash, + * machines can be unable to communicate across the network, etc. + * While individual failures are rare, the larger the job, the greater + * the chance that something, somewhere, will fail. The Google Cloud + * Dataflow service strives to mask such failures automatically, + * principally by retrying failed {@code DoFn} batches. This means + * that a {@code DoFn} instance might process a batch partially, then + * crash for some reason, then be rerun (often on a different worker + * machine) on that same batch and on the same elements as before. + * Sometimes two or more {@code DoFn} instances will be running on the + * same batch simultaneously, with the system taking the results of + * the first instance to complete successfully. Consequently, the + * code in a {@code DoFn} needs to be written such that these + * duplicate (sequential or concurrent) executions do not cause + * problems. If the outputs of a {@code DoFn} are a pure function of + * its inputs, then this requirement is satisfied. However, if a + * {@code DoFn}'s execution has external side-effects, say performing + * updates to external HTTP services, then the {@code DoFn}'s code + * needs to take care to ensure that those updates are idempotent and + * that concurrent updates are acceptable. This property can be + * difficult to achieve, so it is advisable to strive to keep + * {@code DoFn}s as pure functions as much as possible. + * + *

Optimization

+ * + *

The Google Cloud Dataflow service automatically optimizes a + * pipeline before it is executed. A key optimization, fusion, + * relates to ParDo operations. If one ParDo operation produces a + * PCollection that is then consumed as the main input of another + * ParDo operation, the two ParDo operations will be fused + * together into a single ParDo operation and run in a single pass; + * this is "producer-consumer fusion". Similarly, if + * two or more ParDo operations have the same PCollection main input, + * they will be fused into a single ParDo which makes just one pass + * over the input PCollection; this is "sibling fusion". + * + *

If after fusion there are no more unfused references to a + * PCollection (e.g., one between a producer ParDo and a consumer + * ParDo), the PCollection itself is "fused away" and won't ever be + * written to disk, saving all the I/O and space expense of + * constructing it. + * + *

The Google Cloud Dataflow service applies fusion as much as + * possible, greatly reducing the cost of executing pipelines. As a + * result, it is essentially "free" to write ParDo operations in a + * vary modular, composable style, each ParDo operation doing one + * clear task, and stringing together sequences of ParDo operations to + * get the desired overall effect. Such programs can be easier to + * understand, easier to unit-test, easier to extend and evolve, and + * easier to reuse in new programs. The predefined library of + * PTransforms that come with Google Cloud Dataflow makes heavy use of + * this modular, composable style, trusting to the Google Cloud + * Dataflow service's optimizer to "flatten out" all the compositions + * into highly optimized stages. + * + * @see Using ParDo + */ +public class ParDo { + + /** + * Creates a {@code ParDo} {@code PTransform} with the given name. + * + *

See the discussion of Naming above for more explanation. + * + *

The resulting {@code PTransform} is incomplete, and its + * input/output types are not yet bound. Use + * {@link ParDo.Unbound#of} to specify the {@link DoFn} to + * invoke, which will also bind the input/output types of this + * {@code PTransform}. + */ + public static Unbound named(String name) { + return new Unbound().named(name); + } + + /** + * Creates a {@code ParDo} {@code PTransform} with the given + * side inputs. + * + *

Side inputs are {@link PCollectionView}s, whose contents are + * computed during pipeline execution and then made accessible to + * {@code DoFn} code via {@link DoFn.Context#sideInput}. Each + * invocation of the {@code DoFn} receives the same values for these + * side inputs. + * + *

See the discussion of Side Inputs above for more explanation. + * + *

The resulting {@code PTransform} is incomplete, and its + * input/output types are not yet bound. Use + * {@link ParDo.Unbound#of} to specify the {@link DoFn} to + * invoke, which will also bind the input/output types of this + * {@code PTransform}. + */ + public static Unbound withSideInputs(PCollectionView... sideInputs) { + return new Unbound().withSideInputs(sideInputs); + } + + /** + * Creates a {@code ParDo} with the given side inputs. + * + *

Side inputs are {@link PCollectionView}s, whose contents are + * computed during pipeline execution and then made accessible to + * {@code DoFn} code via {@link DoFn.Context#sideInput}. + * + *

See the discussion of Side Inputs above for more explanation. + * + *

The resulting {@code PTransform} is incomplete, and its + * input/output types are not yet bound. Use + * {@link ParDo.Unbound#of} to specify the {@link DoFn} to + * invoke, which will also bind the input/output types of this + * {@code PTransform}. + */ + public static Unbound withSideInputs( + Iterable> sideInputs) { + return new Unbound().withSideInputs(sideInputs); + } + + /** + * Creates a multi-output {@code ParDo} {@code PTransform} whose + * output {@link PCollection}s will be referenced using the given main + * output and side output tags. + * + *

{@link TupleTag}s are used to name (with its static element + * type {@code T}) each main and side output {@code PCollection}. + * This {@code PTransform}'s {@link DoFn} emits elements to the main + * output {@code PCollection} as normal, using + * {@link DoFn.Context#output}. The {@code DoFn} emits elements to + * a side output {@code PCollection} using + * {@link DoFn.Context#sideOutput}, passing that side output's tag + * as an argument. The result of invoking this {@code PTransform} + * will be a {@link PCollectionTuple}, and any of the the main and + * side output {@code PCollection}s can be retrieved from it via + * {@link PCollectionTuple#get}, passing the output's tag as an + * argument. + * + *

See the discussion of Side Outputs above for more explanation. + * + *

The resulting {@code PTransform} is incomplete, and its input + * type is not yet bound. Use {@link ParDo.UnboundMulti#of} + * to specify the {@link DoFn} to invoke, which will also bind the + * input type of this {@code PTransform}. + */ + public static UnboundMulti withOutputTags( + TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + return new Unbound().withOutputTags(mainOutputTag, sideOutputTags); + } + + /** + * Creates a {@code ParDo} {@code PTransform} that will invoke the + * given {@link DoFn} function. + * + *

The resulting {@code PTransform}'s types have been bound, with the + * input being a {@code PCollection} and the output a + * {@code PCollection}, inferred from the types of the argument + * {@code DoFn}. It is ready to be applied, or further + * properties can be set on it first. + */ + public static Bound of(DoFn fn) { + return new Unbound().of(fn); + } + + /** + * An incomplete {@code ParDo} transform, with unbound input/output types. + * + *

Before being applied, {@link ParDo.Unbound#of} must be + * invoked to specify the {@link DoFn} to invoke, which will also + * bind the input/output types of this {@code PTransform}. + */ + public static class Unbound { + String name; + List> sideInputs = Collections.emptyList(); + + Unbound() {} + + Unbound(String name, + List> sideInputs) { + this.name = name; + this.sideInputs = sideInputs; + } + + /** + * Returns a new {@code ParDo} transform that's like this + * transform but with the specified name. Does not modify this + * transform. The resulting transform is still incomplete. + * + *

See the discussion of Naming above for more explanation. + */ + public Unbound named(String name) { + return new Unbound(name, sideInputs); + } + + /** + * Returns a new {@code ParDo} transform that's like this + * transform but with the specified side inputs. + * Does not modify this transform. The resulting transform is + * still incomplete. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Unbound withSideInputs(PCollectionView... sideInputs) { + return new Unbound(name, ImmutableList.copyOf(sideInputs)); + } + + /** + * Returns a new {@code ParDo} transform that's like this + * transform but with the specified side inputs. Does not modify + * this transform. The resulting transform is still incomplete. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Unbound withSideInputs( + Iterable> sideInputs) { + return new Unbound(name, ImmutableList.copyOf(sideInputs)); + } + + /** + * Returns a new multi-output {@code ParDo} transform that's like + * this transform but with the specified main and side output + * tags. Does not modify this transform. The resulting transform + * is still incomplete. + * + *

See the discussion of Side Outputs above and on + * {@link ParDo#withOutputTags} for more explanation. + */ + public UnboundMulti withOutputTags(TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + return new UnboundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags); + } + + /** + * Returns a new {@code ParDo} {@code PTransform} that's like this + * transform but which will invoke the given {@link DoFn} + * function, and which has its input and output types bound. Does + * not modify this transform. The resulting {@code PTransform} is + * sufficiently specified to be applied, but more properties can + * still be specified. + */ + public Bound of(DoFn fn) { + return new Bound<>(name, sideInputs, fn); + } + } + + /** + * A {@code PTransform} that, when applied to a {@code PCollection}, + * invokes a user-specified {@code DoFn} on all its elements, + * with all its outputs collected into an output + * {@code PCollection}. + * + *

A multi-output form of this transform can be created with + * {@link ParDo.Bound#withOutputTags}. + * + * @param the type of the (main) input {@code PCollection} elements + * @param the type of the (main) output {@code PCollection} elements + */ + public static class Bound + extends PTransform, PCollection> { + // Inherits name. + List> sideInputs; + DoFn fn; + + Bound(String name, + List> sideInputs, + DoFn fn) { + super(name); + this.sideInputs = sideInputs; + this.fn = fn; + } + + /** + * Returns a new {@code ParDo} {@code PTransform} that's like this + * {@code PTransform} but with the specified name. Does not + * modify this {@code PTransform}. + * + *

See the discussion of Naming above for more explanation. + */ + public Bound named(String name) { + return new Bound<>(name, sideInputs, fn); + } + + /** + * Returns a new {@code ParDo} {@code PTransform} that's like this + * {@code PTransform} but with the specified side inputs. Does not + * modify this {@code PTransform}. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Bound withSideInputs(PCollectionView... sideInputs) { + return new Bound<>(name, ImmutableList.copyOf(sideInputs), fn); + } + + /** + * Returns a new {@code ParDo} {@code PTransform} that's like this + * {@code PTransform} but with the specified side inputs. Does not + * modify this {@code PTransform}. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Bound withSideInputs( + Iterable> sideInputs) { + return new Bound<>(name, ImmutableList.copyOf(sideInputs), fn); + } + + /** + * Returns a new multi-output {@code ParDo} {@code PTransform} + * that's like this {@code PTransform} but with the specified main + * and side output tags. Does not modify this {@code PTransform}. + * + *

See the discussion of Side Outputs above and on + * {@link ParDo#withOutputTags} for more explanation. + */ + public BoundMulti withOutputTags(TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + return new BoundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags, fn); + } + + @Override + public PCollection apply(PCollection input) { + if (sideInputs == null) { + sideInputs = Collections.emptyList(); + } + return PCollection.createPrimitiveOutputInternal(getInput().getWindowingFn()) + .setTypeTokenInternal(fn.getOutputTypeToken()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return getPipeline().getCoderRegistry().getDefaultCoder( + fn.getOutputTypeToken(), + fn.getInputTypeToken(), + ((PCollection) getInput()).getCoder()); + } + + @Override + protected String getDefaultName() { + return StringUtils.approximateSimpleName(fn.getClass()); + } + + @Override + protected String getKindString() { return "ParDo"; } + + public DoFn getFn() { + return fn; + } + + public List> getSideInputs() { + return sideInputs; + } + } + + /** + * An incomplete multi-output {@code ParDo} transform, with unbound + * input type. + * + *

Before being applied, {@link ParDo.UnboundMulti#of} must be + * invoked to specify the {@link DoFn} to invoke, which will also + * bind the input type of this {@code PTransform}. + * + * @param the type of the main output {@code PCollection} elements + */ + public static class UnboundMulti { + String name; + List> sideInputs; + TupleTag mainOutputTag; + TupleTagList sideOutputTags; + + UnboundMulti(String name, + List> sideInputs, + TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + this.name = name; + this.sideInputs = sideInputs; + this.mainOutputTag = mainOutputTag; + this.sideOutputTags = sideOutputTags; + } + + /** + * Returns a new multi-output {@code ParDo} transform that's like + * this transform but with the specified name. Does not modify + * this transform. The resulting transform is still incomplete. + * + *

See the discussion of Naming above for more explanation. + */ + public UnboundMulti named(String name) { + return new UnboundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags); + } + + /** + * Returns a new multi-output {@code ParDo} transform that's like + * this transform but with the specified side inputs. Does not + * modify this transform. The resulting transform is still + * incomplete. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public UnboundMulti withSideInputs( + PCollectionView... sideInputs) { + return new UnboundMulti<>( + name, ImmutableList.copyOf(sideInputs), + mainOutputTag, sideOutputTags); + } + + /** + * Returns a new multi-output {@code ParDo} transform that's like + * this transform but with the specified side inputs. Does not + * modify this transform. The resulting transform is still + * incomplete. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public UnboundMulti withSideInputs( + Iterable> sideInputs) { + return new UnboundMulti<>( + name, ImmutableList.copyOf(sideInputs), + mainOutputTag, sideOutputTags); + } + + /** + * Returns a new multi-output {@code ParDo} {@code PTransform} + * that's like this transform but which will invoke the given + * {@link DoFn} function, and which has its input type bound. + * Does not modify this transform. The resulting + * {@code PTransform} is sufficiently specified to be applied, but + * more properties can still be specified. + */ + public BoundMulti of(DoFn fn) { + return new BoundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags, fn); + } + } + + /** + * A {@code PTransform} that, when applied to a + * {@code PCollection}, invokes a user-specified + * {@code DoFn} on all its elements, which can emit elements + * to any of the {@code PTransform}'s main and side output + * {@code PCollection}s, which are bundled into a result + * {@code PCollectionTuple}. + * + * @param the type of the (main) input {@code PCollection} elements + * @param the type of the main output {@code PCollection} elements + */ + public static class BoundMulti + extends PTransform, PCollectionTuple> { + // Inherits name. + List> sideInputs; + TupleTag mainOutputTag; + TupleTagList sideOutputTags; + DoFn fn; + + BoundMulti(String name, + List> sideInputs, + TupleTag mainOutputTag, + TupleTagList sideOutputTags, + DoFn fn) { + super(name); + this.sideInputs = sideInputs; + this.mainOutputTag = mainOutputTag; + this.sideOutputTags = sideOutputTags; + this.fn = fn; + } + + /** + * Returns a new multi-output {@code ParDo} {@code PTransform} + * that's like this {@code PTransform} but with the specified + * name. Does not modify this {@code PTransform}. + * + *

See the discussion of Naming above for more explanation. + */ + public BoundMulti named(String name) { + return new BoundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags, fn); + } + + /** + * Returns a new multi-output {@code ParDo} {@code PTransform} + * that's like this {@code PTransform} but with the specified side + * inputs. Does not modify this {@code PTransform}. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public BoundMulti withSideInputs( + PCollectionView... sideInputs) { + return new BoundMulti<>( + name, ImmutableList.copyOf(sideInputs), + mainOutputTag, sideOutputTags, fn); + } + + /** + * Returns a new multi-output {@code ParDo} {@code PTransform} + * that's like this {@code PTransform} but with the specified side + * inputs. Does not modify this {@code PTransform}. + * + *

See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public BoundMulti withSideInputs( + Iterable> sideInputs) { + return new BoundMulti<>( + name, ImmutableList.copyOf(sideInputs), + mainOutputTag, sideOutputTags, fn); + } + + + @Override + public PCollectionTuple apply(PCollection input) { + PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( + TupleTagList.of(mainOutputTag).and(sideOutputTags.getAll()), + getInput().getWindowingFn()); + + // The fn will likely be an instance of an anonymous subclass + // such as DoFn { }, thus will have a high-fidelity + // TypeToken for the output type. + outputs.get(mainOutputTag).setTypeTokenInternal(fn.getOutputTypeToken()); + + return outputs; + } + + @Override + protected Coder getDefaultOutputCoder() { + throw new RuntimeException( + "internal error: shouldn't be calling this on a multi-output ParDo"); + } + + @Override + protected String getDefaultName() { + return StringUtils.approximateSimpleName(fn.getClass()); + } + + @Override + protected String getKindString() { return "ParMultiDo"; } + + public DoFn getFn() { + return fn; + } + + public TupleTag getMainOutputTag() { + return mainOutputTag; + } + + public List> getSideInputs() { + return sideInputs; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateSingleHelper(transform, context); + } + }); + } + + private static void evaluateSingleHelper( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + TupleTag mainOutputTag = new TupleTag<>("out"); + + DirectModeExecutionContext executionContext = new DirectModeExecutionContext(); + + DoFnRunner fnRunner = + evaluateHelper(transform.fn, context.getStepName(transform), + transform.getInput(), transform.sideInputs, + mainOutputTag, new ArrayList>(), + context, executionContext); + + context.setPCollectionValuesWithMetadata( + transform.getOutput(), + executionContext.getOutput(mainOutputTag)); + } + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + BoundMulti.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + BoundMulti transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateMultiHelper(transform, context); + } + }); + } + + private static void evaluateMultiHelper( + BoundMulti transform, + DirectPipelineRunner.EvaluationContext context) { + + DirectModeExecutionContext executionContext = new DirectModeExecutionContext(); + + DoFnRunner fnRunner = + evaluateHelper(transform.fn, context.getStepName(transform), + transform.getInput(), transform.sideInputs, + transform.mainOutputTag, transform.sideOutputTags.getAll(), + context, executionContext); + + for (Map.Entry, PCollection> entry + : transform.getOutput().getAll().entrySet()) { + TupleTag tag = (TupleTag) entry.getKey(); + @SuppressWarnings("unchecked") + PCollection pc = (PCollection) entry.getValue(); + + context.setPCollectionValuesWithMetadata( + pc, + (tag == transform.mainOutputTag + ? executionContext.getOutput(tag) + : executionContext.getSideOutput(tag))); + } + } + + private static DoFnRunner evaluateHelper( + DoFn doFn, + String name, + PCollection input, + List> sideInputs, + TupleTag mainOutputTag, + List> sideOutputTags, + DirectPipelineRunner.EvaluationContext context, + DirectModeExecutionContext executionContext) { + // TODO: Run multiple shards? + DoFn fn = context.ensureSerializable(doFn); + + PTuple sideInputValues = PTuple.empty(); + for (PCollectionView view : sideInputs) { + sideInputValues = sideInputValues.and( + view.getTagInternal(), + context.getPCollectionView(view)); + } + + DoFnRunner fnRunner = + DoFnRunner.createWithListOutputs( + context.getPipelineOptions(), + fn, + sideInputValues, + mainOutputTag, + sideOutputTags, + executionContext.getStepContext(name), + context.getAddCounterMutator()); + + fnRunner.startBundle(); + + for (DirectPipelineRunner.ValueWithMetadata elem + : context.getPCollectionValuesWithMetadata(input)) { + executionContext.setKey(elem.getKey()); + fnRunner.processElement((WindowedValue) elem.getWindowedValue()); + } + + fnRunner.finishBundle(); + + return fnRunner; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Partition.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Partition.java new file mode 100644 index 000000000000..74a1359aa5ed --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Partition.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import java.io.Serializable; + +/** + * {@code Partition} takes a {@code PCollection} and a + * {@code PartitionFn}, uses the {@code PartitionFn} to split the + * elements of the input {@code PCollection} into {@code N} partitions, and + * returns a {@code PCollectionList} that bundles {@code N} + * {@code PCollection}s containing the split elements. + * + *

Example of use: + *

 {@code
+ * PCollection students = ...;
+ * // Split students up into 10 partitions, by percentile:
+ * PCollectionList studentsByPercentile =
+ *     students.apply(Partition.of(10, new PartitionFn() {
+ *         public int partitionFor(Student student, int numPartitions) {
+ *             return student.getPercentile()  // 0..99
+ *                  * numPartitions / 100;
+ *         }}))
+ * for (int i = 0; i < 10; i++) {
+ *   PCollection partition = studentsByPercentile.get(i);
+ *   ...
+ * }
+ * } 
+ * + *

By default, the {@code Coder} of each of the + * {@code PCollection}s in the output {@code PCollectionList} is the + * same as the {@code Coder} of the input {@code PCollection}. + * + *

Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and each output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ +public class Partition + extends PTransform, PCollectionList> { + + /** + * A function object that chooses an output partition for an element. + * + * @param the type of the elements being partitioned + */ + public interface PartitionFn extends Serializable { + /** + * Chooses the partition into which to put the given element. + * + * @param elem the element to be partitioned + * @param numPartitions the total number of partitions ({@code >= 1}) + * @return index of the selected partition (in the range + * {@code [0..numPartitions-1]}) + */ + public int partitionFor(T elem, int numPartitions); + } + + /** + * Returns a new {@code Partition} {@code PTransform} that divides + * its input {@code PCollection} into the given number of partitions, + * using the given partitioning function. + * + * @param numPartitions the number of partitions to divide the input + * {@code PCollection} into + * @param partitionFn the function to invoke on each element to + * choose its output partition + * @throws IllegalArgumentException if {@code numPartitions <= 0} + */ + public static Partition of( + int numPartitions, PartitionFn partitionFn) { + return new Partition<>(new PartitionDoFn(numPartitions, partitionFn)); + } + + @Override + public PCollectionList apply(PCollection in) { + final TupleTagList outputTags = partitionDoFn.getOutputTags(); + + PCollectionTuple outputs = in.apply( + ParDo + .withOutputTags(new TupleTag(){}, outputTags) + .of(partitionDoFn)); + + PCollectionList pcs = PCollectionList.empty(in.getPipeline()); + Coder coder = in.getCoder(); + + for (TupleTag outputTag : outputTags.getAll()) { + // All the tuple tags are actually TupleTag + // And all the collections are actually PCollection + @SuppressWarnings("unchecked") + TupleTag typedOutputTag = (TupleTag) outputTag; + pcs = pcs.and(outputs.get(typedOutputTag).setCoder(coder)); + } + return pcs; + } + + ///////////////////////////////////////////////////////////////////////////// + + private final transient PartitionDoFn partitionDoFn; + + private Partition(PartitionDoFn partitionDoFn) { + this.partitionDoFn = partitionDoFn; + } + + private static class PartitionDoFn extends DoFn { + private final int numPartitions; + private final PartitionFn partitionFn; + private final TupleTagList outputTags; + + /** + * Constructs a PartitionDoFn. + * + * @throws IllegalArgumentException if {@code numPartitions <= 0} + */ + public PartitionDoFn( + int numPartitions, PartitionFn partitionFn) { + if (numPartitions <= 0) { + throw new IllegalArgumentException("numPartitions must be > 0"); + } + + this.numPartitions = numPartitions; + this.partitionFn = partitionFn; + + TupleTagList buildOutputTags = TupleTagList.empty(); + for (int partition = 0; partition < numPartitions; partition++) { + buildOutputTags = buildOutputTags.and(new TupleTag()); + } + outputTags = buildOutputTags; + } + + public TupleTagList getOutputTags() { + return outputTags; + } + + @Override + public void processElement(ProcessContext c) { + T1 input = c.element(); + int partition = partitionFn.partitionFor(input, numPartitions); + if (0 <= partition && partition < numPartitions) { + c.sideOutput((TupleTag) outputTags.get(partition), input); + } else { + throw new IndexOutOfBoundsException( + "Partition function returned out of bounds index: " + + partition + " not in [0.." + numPartitions + ")"); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RateLimiting.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RateLimiting.java new file mode 100644 index 000000000000..2124acfbb84a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RateLimiting.java @@ -0,0 +1,336 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.api.client.util.Throwables; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.reflect.TypeToken; +import com.google.common.util.concurrent.RateLimiter; + +import org.joda.time.Instant; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Provides rate-limiting of user functions, using threaded execution and a + * {@link com.google.common.util.concurrent.RateLimiter} to process elements + * at the desired rate. + * + *

For example, to limit each worker to 10 requests per second: + *

{@code
+ * PCollection data = ...;
+ * data.apply(
+ *   RateLimiting.perWorker(new MyDoFn())
+ *               .withRateLimit(10)));
+ * }
+ * + *

An uncaught exception from the wrapped DoFn will result in the exception + * being rethrown in later calls to {@link RateLimitingDoFn#processElement} + * or a call to {@link RateLimitingDoFn#finishBundle}. + * + *

Rate limiting is provided as a PTransform + * ({@link RateLimitingTransform}), and also as a {@code DoFn} + * ({@link RateLimitingDoFn}). + */ +public class RateLimiting { + + /** + * Creates a new per-worker rate-limiting transform for the given + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn}. + * + *

The default behavior is to process elements with multiple threads, + * but no rate limit is applied. + * + *

Use {@link RateLimitingTransform#withRateLimit} to limit the processing + * rate, and {@link RateLimitingTransform#withMaxParallelism} to control the + * maximum concurrent processing limit. + * + *

Aside from the above, the {@code DoFn} will be executed in the same manner + * as in {@link ParDo}. + * + *

Rate limiting is applied independently per-worker. + */ + public static RateLimitingTransform perWorker(DoFn doFn) { + return new RateLimitingTransform<>(doFn); + } + + /** + * A {@link PTransform} which applies rate limiting to a {@link DoFn}. + * + * @param the type of the (main) input elements + * @param the type of the (main) output elements + */ + public static class RateLimitingTransform + extends PTransform, PCollection> { + private final DoFn doFn; + private double rate = 0.0; + // TODO: set default based on num cores, or based on rate limit? + private int maxParallelism = DEFAULT_MAX_PARALLELISM; + + public RateLimitingTransform(DoFn doFn) { + this.doFn = doFn; + } + + /** + * Modifies this {@code RateLimitingTransform}, specifying a maximum + * per-worker element processing rate. + * + *

A rate of {@code N} corresponds to {@code N} elements per second. + * This rate is on a per-worker basis, so the overall rate of the job + * depends upon the number of workers. + * + *

This rate limit may not be reachable unless there is sufficient + * parallelism. + * + *

A rate of <= 0.0 disables rate limiting. + */ + public RateLimitingTransform withRateLimit( + double maxElementsPerSecond) { + this.rate = maxElementsPerSecond; + return this; + } + + /** + * Modifies this {@code RateLimitingTransform}, specifying a maximum + * per-worker parallelism. + * + *

This determines how many concurrent elements will be processed by the + * wrapped {@code DoFn}. + * + *

The desired amount of parallelism depends upon the type of work. For + * CPU-intensive work, a good starting point is to use the number of cores: + * {@code Runtime.getRuntime().availableProcessors()}. + */ + public RateLimitingTransform withMaxParallelism(int max) { + this.maxParallelism = max; + return this; + } + + @Override + public PCollection apply(PCollection input) { + return input.apply( + ParDo.of(new RateLimitingDoFn<>(doFn, rate, maxParallelism))); + } + } + + /** + * A rate-limiting {@code DoFn} wrapper. + * + * @see RateLimiting#perWorker(DoFn) + * + * @param the type of the (main) input elements + * @param the type of the (main) output elements + */ + public static class RateLimitingDoFn extends DoFn { + private static final Logger LOG = LoggerFactory.getLogger(RateLimitingDoFn.class); + + public RateLimitingDoFn(DoFn doFn, double rateLimit, + int maxParallelism) { + this.doFn = doFn; + this.rate = rateLimit; + this.maxParallelism = maxParallelism; + } + + @Override + public void startBundle(Context c) throws Exception { + doFn.startBundle(c); + + if (rate > 0.0) { + limiter = RateLimiter.create(rate); + } + executor = Executors.newCachedThreadPool(); + workTickets = new Semaphore(maxParallelism); + failure = new AtomicReference<>(); + } + + @Override + public void processElement(final ProcessContext c) throws Exception { + // Apply rate limiting up front, controlling the availability of work for + // the thread pool. This allows us to use an auto-scaling thread pool, + // which adapts the parallelism to the available work. + // The semaphore is used to avoid overwhelming the executor, by bounding + // the number of outstanding elements. + if (limiter != null) { + limiter.acquire(); + } + try { + workTickets.acquire(); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while scheduling work", e); + } + + if (failure.get() != null) { + throw Throwables.propagate(failure.get()); + } + + executor.submit(new Runnable() { + @Override + public void run() { + try { + doFn.processElement(new WrappedContext(c)); + } catch (Throwable t) { + failure.compareAndSet(null, t); + Throwables.propagateIfPossible(t); + throw new AssertionError("Unexpected checked exception: " + t); + } finally { + workTickets.release(); + } + } + }); + } + + @Override + public void finishBundle(Context c) throws Exception { + executor.shutdown(); + // Log a periodic progress report until the queue has drained. + while (true) { + try { + if (executor.awaitTermination(30, TimeUnit.SECONDS)) { + if (failure.get() != null) { + // Handle failure propagation outside of the try/catch block. + break; + } + doFn.finishBundle(c); + return; + } + int outstanding = workTickets.getQueueLength() + + maxParallelism - workTickets.availablePermits(); + LOG.info("RateLimitingDoFn backlog: {}", outstanding); + } catch (InterruptedException e) { + throw Throwables.propagate(e); + } + } + + throw Throwables.propagate(failure.get()); + } + + @Override + TypeToken getInputTypeToken() { + return doFn.getInputTypeToken(); + } + + @Override + TypeToken getOutputTypeToken() { + return doFn.getOutputTypeToken(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Wraps a DoFn context, forcing single-thread output so that threads don't + * propagate through to downstream functions. + */ + private class WrappedContext extends ProcessContext { + private final ProcessContext context; + + WrappedContext(ProcessContext context) { + this.context = context; + } + + @Override + public I element() { + return context.element(); + } + + @Override + public KeyedState keyedState() { + return context.keyedState(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return context.getPipelineOptions(); + } + + @Override + public T sideInput(PCollectionView view) { + return context.sideInput(view); + } + + @Override + public void output(O output) { + synchronized (RateLimitingDoFn.this) { + context.output(output); + } + } + + @Override + public void outputWithTimestamp(O output, Instant timestamp) { + synchronized (RateLimitingDoFn.this) { + context.outputWithTimestamp(output, timestamp); + } + } + + @Override + public void sideOutput(TupleTag tag, T output) { + synchronized (RateLimitingDoFn.this) { + context.sideOutput(tag, output); + } + } + + @Override + public Aggregator createAggregator( + String name, Combine.CombineFn combiner) { + return context.createAggregator(name, combiner); + } + + @Override + public Aggregator createAggregator( + String name, SerializableFunction, AO> combiner) { + return context.createAggregator(name, combiner); + } + + @Override + public Instant timestamp() { + return context.timestamp(); + } + + @Override + public Collection windows() { + return context.windows(); + } + } + + private final DoFn doFn; + private double rate; + private int maxParallelism; + + private transient RateLimiter limiter; + private transient ExecutorService executor; + private transient Semaphore workTickets; + private transient AtomicReference failure; + } + + /** + * Default maximum for number of concurrent elements to process. + */ + @VisibleForTesting + static final int DEFAULT_MAX_PARALLELISM = 16; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicates.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicates.java new file mode 100644 index 000000000000..0e4f21f75b78 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicates.java @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code RemoveDuplicates} takes a {@code PCollection} and + * returns a {@code PCollection} that has all the elements of the + * input but with duplicate elements removed such that each element is + * unique within each window. + * + *

Two values of type {@code T} are compared for equality not by + * regular Java {@link Object#equals}, but instead by first encoding + * each of the elements using the {@code PCollection}'s {@code Coder}, and then + * comparing the encoded bytes. This admits efficient parallel + * evaluation. + * + *

By default, the {@code Coder} of the output {@code PCollection} + * is the same as the {@code Coder} of the input {@code PCollection}. + * + *

Each output element is in the same window as its corresponding input + * element, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * as the input. + * + *

Does not preserve any order the input PCollection might have had. + * + *

Example of use: + *

 {@code
+ * PCollection words = ...;
+ * PCollection uniqueWords =
+ *     words.apply(RemoveDuplicates.create());
+ * } 
+ * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ +public class RemoveDuplicates extends PTransform, + PCollection> { + /** + * Returns a {@code RemoveDuplicates} {@code PTransform}. + * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ + public static RemoveDuplicates create() { + return new RemoveDuplicates<>(); + } + + private RemoveDuplicates() { } + + @Override + public PCollection apply(PCollection in) { + return + in + .apply(ParDo.named("CreateIndex") + .of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), (Void) null)); + } + })) + .apply(Combine.perKey( + new SerializableFunction, Void>() { + @Override + public Void apply(Iterable iter) { + return null; // ignore input + } + })) + .apply(Keys.create()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sample.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sample.java new file mode 100644 index 000000000000..832cc996ea76 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sample.java @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * {@code PTransform}s for taking samples of the elements in a + * {@code PCollection}, or samples of the values associated with each + * key in a {@code PCollection} of {@code KV}s. + **/ +public class Sample { + /** + * Returns a {@code PTransform} that takes a {@code PCollection}, + * selects {@code sampleSize} elements, uniformly at random, and returns a + * {@code PCollection>} containing the selected elements. + * If the input {@code PCollection} has fewer than + * {@code sampleSize} elements, then the output {@code Iterable} + * will be all the input's elements. + * + *

Example of use: + *

 {@code
+   * PCollection pc = ...;
+   * PCollection> sampleOfSize10 =
+   *     pc.apply(Sample.fixedSizeGlobally(10));
+   * } 
+ * + * @param sampleSize the number of elements to select; must be {@code >= 0} + * @param the type of the elements + */ + public static PTransform, PCollection>> + fixedSizeGlobally(int sampleSize) { + return Combine.globally(new FixedSizedSampleFn(sampleSize)); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to a sample of {@code sampleSize} values + * associated with that key in the input {@code PCollection}, taken + * uniformly at random. If a key in the input {@code PCollection} + * has fewer than {@code sampleSize} values associated with it, then + * the output {@code Iterable} associated with that key will be + * all the values associated with that key in the input + * {@code PCollection}. + * + *

Example of use: + *

 {@code
+   * PCollection> pc = ...;
+   * PCollection>> sampleOfSize10PerKey =
+   *     pc.apply(Sample.fixedSizePerKey());
+   * } 
+ * + * @param sampleSize the number of values to select for each + * distinct key; must be {@code >= 0} + * @param the type of the keys + * @param the type of the values + */ + public static PTransform>, + PCollection>>> + fixedSizePerKey(int sampleSize) { + return Combine.perKey(new FixedSizedSampleFn(sampleSize)); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code CombineFn} that computes a fixed-size sample of a + * collection of values. + * + * @param the type of the elements + */ + public static class FixedSizedSampleFn + extends CombineFn>.Heap, Iterable> { + private final Top.TopCombineFn> topCombineFn; + private final Random rand = new Random(); + + private FixedSizedSampleFn(int sampleSize) { + if (sampleSize < 0) { + throw new IllegalArgumentException("sample size must be >= 0"); + } + topCombineFn = new Top.TopCombineFn<>(sampleSize, + new KV.OrderByKey()); + } + + @Override + public Top.TopCombineFn>.Heap createAccumulator() { + return topCombineFn.createAccumulator(); + } + + @Override + public void addInput(Top.TopCombineFn>.Heap accumulator, + T input) { + accumulator.addInput(KV.of(rand.nextInt(), input)); + } + + @Override + public Top.TopCombineFn>.Heap mergeAccumulators( + Iterable>.Heap> accumulators) { + return topCombineFn.mergeAccumulators(accumulators); + } + + @Override + public Iterable extractOutput( + Top.TopCombineFn>.Heap accumulator) { + List out = new ArrayList<>(); + for (KV element : accumulator.extractOutput()) { + out.add(element.getValue()); + } + return out; + } + + @Override + public Coder>.Heap> getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return topCombineFn.getAccumulatorCoder( + registry, KvCoder.of(BigEndianIntegerCoder.of(), inputCoder)); + } + + @Override + public Coder> getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) { + return IterableCoder.of(inputCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableComparator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableComparator.java new file mode 100644 index 000000000000..3d538faa54d8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableComparator.java @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * A {@code Serializable} {@code Comparator}. + * + * @param type of values being compared + */ +public interface SerializableComparator extends Comparator, Serializable { +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableFunction.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableFunction.java new file mode 100644 index 000000000000..857491a11fe8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableFunction.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import java.io.Serializable; + +/** + * A function that computes an output value based on an input value, + * and is {@link Serializable}. + * + * @param input value type + * @param output value type + */ +public interface SerializableFunction extends Serializable { + /** Returns the result of invoking this function on the given input. */ + public O apply(I input); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sum.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sum.java new file mode 100644 index 000000000000..e925e4a5cc90 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sum.java @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +/** + * {@code PTransform}s for computing the sum of the elements in a + * {@code PCollection}, or the sum of the values associated with + * each key in a {@code PCollection} of {@code KV}s. + * + *

Example 1: get the sum of a {@code PCollection} of {@code Double}s. + *

 {@code
+ * PCollection input = ...;
+ * PCollection sum = input.apply(Sum.doublesGlobally());
+ * } 
+ * + *

Example 2: calculate the sum of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

 {@code
+ * PCollection> input = ...;
+ * PCollection> sumPerKey = input
+ *     .apply(Sum.integersPerKey());
+ * } 
+ */ +public class Sum { + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the sum of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + */ + public static Combine.Globally integersGlobally() { + Combine.Globally combine = Combine + .globally(new SumIntegerFn()); + combine.setName("Sum"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the sum of the values associated with + * that key in the input {@code PCollection}. + */ + public static Combine.PerKey integersPerKey() { + Combine.PerKey combine = Combine + .perKey(new SumIntegerFn()); + combine.setName("Sum.PerKey"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the sum of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + */ + public static Combine.Globally longsGlobally() { + Combine.Globally combine = Combine.globally(new SumLongFn()); + combine.setName("Sum"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the sum of the values associated with + * that key in the input {@code PCollection}. + */ + public static Combine.PerKey longsPerKey() { + Combine.PerKey combine = Combine + .perKey(new SumLongFn()); + combine.setName("Sum.PerKey"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the sum of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + */ + public static Combine.Globally doublesGlobally() { + Combine.Globally combine = Combine + .globally(new SumDoubleFn()); + combine.setName("Sum"); + return combine; + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the sum of the values associated with + * that key in the input {@code PCollection}. + */ + public static Combine.PerKey doublesPerKey() { + Combine.PerKey combine = Combine + .perKey(new SumDoubleFn()); + combine.setName("Sum.PerKey"); + return combine; + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code SerializableFunction} that computes the sum of an + * {@code Iterable} of {@code Integer}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class SumIntegerFn + implements SerializableFunction, Integer> { + @Override + public Integer apply(Iterable input) { + int sum = 0; + for (int value : input) { + sum += value; + } + return sum; + } + } + + /** + * A {@code SerializableFunction} that computes the sum of an + * {@code Iterable} of {@code Long}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class SumLongFn + implements SerializableFunction, Long> { + @Override + public Long apply(Iterable input) { + long sum = 0; + for (long value : input) { + sum += value; + } + return sum; + } + } + + /** + * A {@code SerializableFunction} that computes the sum of an + * {@code Iterable} of {@code Double}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class SumDoubleFn + implements SerializableFunction, Double> { + @Override + public Double apply(Iterable input) { + double sum = 0; + for (double value : input) { + sum += value; + } + return sum; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Top.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Top.java new file mode 100644 index 000000000000..1f63808fc223 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Top.java @@ -0,0 +1,489 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; + +/** + * {@code PTransform}s for finding the largest (or smallest) set + * of elements in a {@code PCollection}, or the largest (or smallest) + * set of values associated with each key in a {@code PCollection} of + * {@code KV}s. + */ +public class Top { + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection>} with a + * single element containing the largest {@code count} elements of the input + * {@code PCollection}, in decreasing order, sorted using the + * given {@code Comparator}. The {@code Comparator} must also + * be {@code Serializable}. + * + *

If {@code count} {@code <} the number of elements in the + * input {@code PCollection}, then all the elements of the input + * {@code PCollection} will be in the resulting + * {@code List}, albeit in sorted order. + * + *

All the elements of the result's {@code List} + * must fit into the memory of a single machine. + * + *

Example of use: + *

 {@code
+   * PCollection students = ...;
+   * PCollection> top10Students =
+   *     students.apply(Top.of(10, new CompareStudentsByAvgGrade()));
+   * } 
+ * + *

By default, the {@code Coder} of the output {@code PCollection} + * is a {@code ListCoder} of the {@code Coder} of the elements of + * the input {@code PCollection}. + * + *

See also {@link #smallest} and {@link #largest}, which sort + * {@code Comparable} elements using their natural ordering. + * + *

See also {@link #perKey}, {@link #smallestPerKey}, and + * {@link #largestPerKey} which take a {@code PCollection} of + * {@code KV}s and return the top values associated with each key. + */ + public static & Serializable> + PTransform, PCollection>> of(int count, C compareFn) { + return Combine.globally(new TopCombineFn<>(count, compareFn)) + .withName("Top"); + + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection>} with a + * single element containing the smallest {@code count} elements of the input + * {@code PCollection}, in increasing order, sorted according to + * their natural order. + * + *

If {@code count} {@code <} the number of elements in the + * input {@code PCollection}, then all the elements of the input + * {@code PCollection} will be in the resulting {@code PCollection}'s + * {@code List}, albeit in sorted order. + * + *

All the elements of the result {@code List} + * must fit into the memory of a single machine. + * + *

Example of use: + *

 {@code
+   * PCollection values = ...;
+   * PCollection> smallest10Values = values.apply(Top.smallest(10));
+   * } 
+ * + *

By default, the {@code Coder} of the output {@code PCollection} + * is a {@code ListCoder} of the {@code Coder} of the elements of + * the input {@code PCollection}. + * + *

See also {@link #largest}. + * + *

See also {@link #of}, which sorts using a user-specified + * {@code Comparator} function. + * + *

See also {@link #perKey}, {@link #smallestPerKey}, and + * {@link #largestPerKey} which take a {@code PCollection} of + * {@code KV}s and return the top values associated with each key. + */ + public static > + PTransform, PCollection>> smallest(int count) { + return Combine.globally(new TopCombineFn<>(count, new Smallest())) + .withName("Top.Smallest"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection>} with a + * single element containing the largest {@code count} elements of the input + * {@code PCollection}, in decreasing order, sorted according to + * their natural order. + * + *

If {@code count} {@code <} the number of elements in the + * input {@code PCollection}, then all the elements of the input + * {@code PCollection} will be in the resulting {@code PCollection}'s + * {@code List}, albeit in sorted order. + * + *

All the elements of the result's {@code List} + * must fit into the memory of a single machine. + * + *

Example of use: + *

 {@code
+   * PCollection values = ...;
+   * PCollection> largest10Values = values.apply(Top.largest(10));
+   * } 
+ * + *

By default, the {@code Coder} of the output {@code PCollection} + * is a {@code ListCoder} of the {@code Coder} of the elements of + * the input {@code PCollection}. + * + *

See also {@link #smallest}. + * + *

See also {@link #of}, which sorts using a user-specified + * {@code Comparator} function. + * + *

See also {@link #perKey}, {@link #smallestPerKey}, and + * {@link #largestPerKey} which take a {@code PCollection} of + * {@code KV}s and return the top values associated with each key. + */ + public static > + PTransform, PCollection>> largest(int count) { + return Combine.globally(new TopCombineFn<>(count, new Largest())) + .withName("Top.Largest"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the largest {@code count} values + * associated with that key in the input + * {@code PCollection>}, in decreasing order, sorted using + * the given {@code Comparator}. The + * {@code Comparator} must also be {@code Serializable}. + * + *

If there are fewer than {@code count} values associated with + * a particular key, then all those values will be in the result + * mapping for that key, albeit in sorted order. + * + *

All the values associated with a single key must fit into the + * memory of a single machine, but there can be many more + * {@code KV}s in the resulting {@code PCollection} than can fit + * into the memory of a single machine. + * + *

Example of use: + *

 {@code
+   * PCollection> studentsBySchool = ...;
+   * PCollection>> top10StudentsBySchool =
+   *     studentsBySchool.apply(
+   *         Top.perKey(10, new CompareStudentsByAvgGrade()));
+   * } 
+ * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input + * {@code PCollection}, and the {@code Coder} of the values of the + * output {@code PCollection} is a {@code ListCoder} of the + * {@code Coder} of the values of the input {@code PCollection}. + * + *

See also {@link #smallestPerKey} and {@link #largestPerKey}, + * which sort {@code Comparable} values using their natural + * ordering. + * + *

See also {@link #of}, {@link #smallest}, and {@link #largest} + * which take a {@code PCollection} and return the top elements. + */ + public static & Serializable> + PTransform>, PCollection>>> + perKey(int count, C compareFn) { + return Combine.perKey( + new TopCombineFn<>(count, compareFn).asKeyedFn()) + .withName("Top.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the smallest {@code count} values + * associated with that key in the input + * {@code PCollection>}, in increasing order, sorted + * according to their natural order. + * + *

If there are fewer than {@code count} values associated with + * a particular key, then all those values will be in the result + * mapping for that key, albeit in sorted order. + * + *

All the values associated with a single key must fit into the + * memory of a single machine, but there can be many more + * {@code KV}s in the resulting {@code PCollection} than can fit + * into the memory of a single machine. + * + *

Example of use: + *

 {@code
+   * PCollection> keyedValues = ...;
+   * PCollection>> smallest10ValuesPerKey =
+   *     keyedValues.apply(Top.smallestPerKey(10));
+   * } 
+ * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input + * {@code PCollection}, and the {@code Coder} of the values of the + * output {@code PCollection} is a {@code ListCoder} of the + * {@code Coder} of the values of the input {@code PCollection}. + * + *

See also {@link #largestPerKey}. + * + *

See also {@link #perKey}, which sorts values using a user-specified + * {@code Comparator} function. + * + *

See also {@link #of}, {@link #smallest}, and {@link #largest} + * which take a {@code PCollection} and return the top elements. + */ + public static > + PTransform>, PCollection>>> + smallestPerKey(int count) { + return Combine.perKey( + new TopCombineFn<>(count, new Smallest()).asKeyedFn()) + .withName("Top.SmallestPerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the largest {@code count} values + * associated with that key in the input + * {@code PCollection>}, in decreasing order, sorted + * according to their natural order. + * + *

If there are fewer than {@code count} values associated with + * a particular key, then all those values will be in the result + * mapping for that key, albeit in sorted order. + * + *

All the values associated with a single key must fit into the + * memory of a single machine, but there can be many more + * {@code KV}s in the resulting {@code PCollection} than can fit + * into the memory of a single machine. + * + *

Example of use: + *

 {@code
+   * PCollection> keyedValues = ...;
+   * PCollection>> largest10ValuesPerKey =
+   *     keyedValues.apply(Top.largestPerKey(10));
+   * } 
+ * + *

By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input + * {@code PCollection}, and the {@code Coder} of the values of the + * output {@code PCollection} is a {@code ListCoder} of the + * {@code Coder} of the values of the input {@code PCollection}. + * + *

See also {@link #smallestPerKey}. + * + *

See also {@link #perKey}, which sorts values using a user-specified + * {@code Comparator} function. + * + *

See also {@link #of}, {@link #smallest}, and {@link #largest} + * which take a {@code PCollection} and return the top elements. + */ + public static > + PTransform>, PCollection>>> + largestPerKey(int count) { + return Combine.perKey( + new TopCombineFn<>(count, new Largest()).asKeyedFn()) + .withName("Top.LargestPerKey"); + } + + + //////////////////////////////////////////////////////////////////////////// + + /** + * {@code CombineFn} for {@code Top} transforms that combines a + * bunch of {@code T}s into a single {@code count}-long + * {@code List}, using {@code compareFn} to choose the largest + * {@code T}s. + * + * @param type of element being compared + */ + public static class TopCombineFn + extends AccumulatingCombineFn.Heap, List> { + + private final int count; + private final Comparator compareFn; + + public & Serializable> TopCombineFn( + int count, C compareFn) { + if (count < 0) { + throw new IllegalArgumentException("count must be >= 0"); + } + this.count = count; + this.compareFn = compareFn; + } + + class Heap + // TODO: Why do I have to fully qualify the + // Accumulator class here? + extends AccumulatingCombineFn.Heap, List> + .Accumulator { + + // Exactly one of these should be set. + private List asList; // ordered largest first + private PriorityQueue asQueue; // head is smallest + + private Heap(List asList) { + this.asList = asList; + } + + @Override + public void addInput(T value) { + addInputInternal(value); + } + + private boolean addInputInternal(T value) { + if (count == 0) { + // Don't add anything. + return false; + } + + if (asQueue == null) { + asQueue = new PriorityQueue<>(count, compareFn); + for (T item : asList) { + asQueue.add(item); + } + asList = null; + } + + if (asQueue.size() < count) { + asQueue.add(value); + return true; + } else if (compareFn.compare(value, asQueue.peek()) > 0) { + asQueue.poll(); + asQueue.add(value); + return true; + } else { + return false; + } + } + + @Override + public void mergeAccumulator(Heap accumulator) { + for (T value : accumulator.asList()) { + if (!addInputInternal(value)) { + // The list is ordered, remainder will also all be smaller. + break; + } + } + } + + @Override + public List extractOutput() { + return asList(); + } + + private List asList() { + if (asList == null) { + int index = asQueue.size(); + @SuppressWarnings("unchecked") + T[] ordered = (T[]) new Object[index]; + while (!asQueue.isEmpty()) { + index--; + ordered[index] = asQueue.poll(); + } + asList = Arrays.asList(ordered); + asQueue = null; + } + return asList; + } + } + + @Override + public Heap createAccumulator() { + return new Heap(new ArrayList()); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return new HeapCoder(inputCoder); + } + + private class HeapCoder extends CustomCoder { + private final Coder> listCoder; + + public HeapCoder(Coder inputCoder) { + listCoder = ListCoder.of(inputCoder); + } + + @Override + public void encode(Heap value, OutputStream outStream, + Context context) throws CoderException, IOException { + listCoder.encode(value.asList(), outStream, context); + } + + @Override + public Heap decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + return new Heap(listCoder.decode(inStream, context)); + } + + @Override + public boolean isDeterministic() { + return listCoder.isDeterministic(); + } + + @Override + public boolean isRegisterByteSizeObserverCheap( + Heap value, Context context) { + return listCoder.isRegisterByteSizeObserverCheap( + value.asList(), context); + } + + @Override + public void registerByteSizeObserver( + Heap value, ElementByteSizeObserver observer, Context context) + throws Exception { + listCoder.registerByteSizeObserver(value.asList(), observer, context); + } + }; + } + + /** + * {@code Serializable} {@code Comparator} that that uses the + * compared elements' natural ordering. + */ + public static class Largest> + implements Comparator, Serializable { + @Override + public int compare(T a, T b) { + return a.compareTo(b); + } + } + + /** + * {@code Serializable} {@code Comparator} that that uses the + * reverse of the compared elements' natural ordering. + */ + public static class Smallest> + implements Comparator, Serializable { + @Override + public int compare(T a, T b) { + return b.compareTo(a); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Values.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Values.java new file mode 100644 index 000000000000..ae008b196ad3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Values.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code Values} takes a {@code PCollection} of {@code KV}s and + * returns a {@code PCollection} of the values. + * + *

Example of use: + *

 {@code
+ * PCollection> wordCounts = ...;
+ * PCollection counts = wordCounts.apply(Values.create());
+ * } 
+ * + *

Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + *

See also {@link Keys}. + * + * @param the type of the values in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ +public class Values extends PTransform>, + PCollection> { + /** + * Returns a {@code Values} {@code PTransform}. + * + * @param the type of the values in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ + public static Values create() { + return new Values<>(); + } + + private Values() { } + + @Override + public PCollection apply(PCollection> in) { + return + in.apply(ParDo.named("Values") + .of(new DoFn, V>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getValue()); + } + })); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/View.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/View.java new file mode 100644 index 000000000000..d3bb86388870 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/View.java @@ -0,0 +1,211 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValueBase; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Function; +import com.google.common.collect.Iterables; + +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Transforms for creating {@link PCollectionView}s from {@link PCollection}s, + * for consuming the contents of those {@link PCollection}s as side inputs + * to {@link ParDo} transforms. + */ +public class View { + + // Do not instantiate + private View() { } + + /** + * Returns a {@link AsSingleton} transform that takes a singleton + * {@link PCollection} as input and produces a {@link PCollectionView} + * of the single value, to be consumed as a side input. + * + *

If the input {@link PCollection} is empty, + * throws {@link NoSuchElementException} in the consuming + * {@link DoFn}. + * + *

If the input {@link PCollection} contains more than one + * element, throws {@link IllegalArgumentException} in the + * consuming {@link DoFn}. + */ + public static AsSingleton asSingleton() { + return new AsSingleton<>(); + } + + /** + * Returns a {@link AsIterable} that takes a + * {@link PCollection} as input and produces a {@link PCollectionView} + * of the values, to be consumed as an iterable side input. + */ + public static AsIterable asIterable() { + return new AsIterable<>(); + } + + + /** + * A {@PTransform} that produces a {@link PCollectionView} of a singleton {@link PCollection} + * yielding the single element it contains. + * + *

Instantiate via {@link View.asIterable}. + */ + public static class AsIterable extends PTransform< + PCollection, + PCollectionView, Iterable>>> { + + private AsIterable() { } + + @Override + public PCollectionView, Iterable>> apply( + PCollection input) { + return input.apply( + new CreatePCollectionView, Iterable>>( + new IterablePCollectionView(input.getPipeline()))); + } + } + + /** + * A {@PTransform} that produces a {@link PCollectionView} of a singleton {@link PCollection} + * yielding the single element it contains. + * + *

Instantiate via {@link View.asIterable}. + */ + public static class AsSingleton + extends PTransform, PCollectionView>> { + + private AsSingleton() { } + + @Override + public PCollectionView> apply(PCollection input) { + return input.apply( + new CreatePCollectionView>( + new SingletonPCollectionView(input.getPipeline()))); + } + + } + + + //////////////////////////////////////////////////////////////////////////// + // Internal details below + + /** + * Creates a primitive PCollectionView. + * + *

For internal use only. + * + * @param The type of the elements of the input PCollection + * @param The type associated with the PCollectionView used as a side input + * @param The type associated with a windowed side input from the + * PCollectionView + */ + public static class CreatePCollectionView + extends PTransform, PCollectionView> { + + private PCollectionView view; + + public CreatePCollectionView(PCollectionView view) { + this.view = view; + } + + @Override + public PCollectionView apply(PCollection input) { + return view; + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + CreatePCollectionView.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + CreatePCollectionView transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateTyped(transform, context); + } + + private void evaluateTyped( + CreatePCollectionView transform, + DirectPipelineRunner.EvaluationContext context) { + List> elems = + context.getPCollectionWindowedValues(transform.getInput()); + context.setPCollectionView(transform.getOutput(), elems); + } + }); + } + } + + private static class SingletonPCollectionView + extends PCollectionViewBase> { + + public SingletonPCollectionView(Pipeline pipeline) { + setPipelineInternal(pipeline); + } + + @Override + public T fromIterableInternal(Iterable> contents) { + try { + return (T) Iterables.getOnlyElement(contents).getValue(); + } catch (NoSuchElementException exc) { + throw new NoSuchElementException( + "Empty PCollection accessed as a singleton view."); + } catch (IllegalArgumentException exc) { + throw new IllegalArgumentException( + "PCollection with more than one element " + + "accessed as a singleton view."); + } + } + } + + private static class IterablePCollectionView + extends PCollectionViewBase, Iterable>> { + + public IterablePCollectionView(Pipeline pipeline) { + setPipelineInternal(pipeline); + } + + @Override + public Iterable fromIterableInternal(Iterable> contents) { + return Iterables.transform(contents, new Function, T>() { + @Override + public T apply(WindowedValue input) { + return (T) input.getValue(); + } + }); + } + } + + private abstract static class PCollectionViewBase + extends PValueBase + implements PCollectionView { + + @Override + public TupleTag>> getTagInternal() { + return tag; + } + + private TupleTag>> tag = new TupleTag<>(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithKeys.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithKeys.java new file mode 100644 index 000000000000..1754c20a7916 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithKeys.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.reflect.TypeToken; + +/** + * {@code WithKeys} takes a {@code PCollection}, and either a + * constant key of type {@code K} or a function from {@code V} to + * {@code K}, and returns a {@code PCollection>}, where each + * of the values in the input {@code PCollection} has been paired with + * either the constant key or a key computed from the value. + * + *

Example of use: + *

 {@code
+ * PCollection words = ...;
+ * PCollection> lengthsToWords =
+ *     words.apply(WithKeys.of(new SerializableFunction() {
+ *         public Integer apply(String s) { return s.length(); } }));
+ * } 
+ * + *

Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn} + * associated with it as the input. + * + * @param the type of the keys in the output {@code PCollection} + * @param the type of the elements in the input + * {@code PCollection} and the values in the output + * {@code PCollection} + */ +public class WithKeys extends PTransform, + PCollection>> { + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection>}, where each of the + * values in the input {@code PCollection} has been paired with a + * key computed from the value by invoking the given + * {@code SerializableFunction}. + */ + public static WithKeys of(SerializableFunction fn) { + return new WithKeys<>(fn, null); + } + + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection>}, where each of the + * values in the input {@code PCollection} has been paired with the + * given key. + */ + @SuppressWarnings("unchecked") + public static WithKeys of(final K key) { + return new WithKeys<>( + new SerializableFunction() { + @Override + public K apply(V value) { + return key; + } + }, + (Class) (key == null ? null : key.getClass())); + } + + + ///////////////////////////////////////////////////////////////////////////// + + private SerializableFunction fn; + private transient Class keyClass; + + private WithKeys(SerializableFunction fn, Class keyClass) { + this.fn = fn; + this.keyClass = keyClass; + } + + @Override + public PCollection> apply(PCollection in) { + Coder keyCoder; + if (keyClass == null) { + keyCoder = getCoderRegistry().getDefaultOutputCoder(fn, in.getCoder()); + } else { + keyCoder = getCoderRegistry().getDefaultCoder(TypeToken.of(keyClass)); + } + PCollection> result = + in.apply(ParDo.named("AddKeys") + .of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(fn.apply(c.element()), + c.element())); + } + })); + if (keyCoder != null) { + // TODO: Remove when we can set the coder inference context. + result.setCoder(KvCoder.of(keyCoder, in.getCoder())); + } + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResult.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResult.java new file mode 100644 index 000000000000..f91d7d2ca669 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResult.java @@ -0,0 +1,367 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static com.google.cloud.dataflow.sdk.util.Structs.addObject; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * A row result of a CoGroupByKey. This is a tuple of Iterables produced for + * a given key, and these can be accessed in different ways. + */ +public class CoGbkResult { + // TODO: If we keep this representation for any amount of time, + // optimize it so that the union tag does not have to be repeated in the + // values stored under the union tag key. + /** + * A map of integer union tags to a list of union objects. + * Note: the key and the embedded union tag are the same, so it is redundant + * to store it multiple times, but for now it makes encoding easier. + */ + private final Map> valueMap; + + private final CoGbkResultSchema schema; + + /** + * A row in the PCollection resulting from a CoGroupByKey transform. + * Currently, this row must fit into memory. + * + * @param schema the set of tuple tags used to refer to input tables and + * result values + * @param values the raw results from a group-by-key + */ + @SuppressWarnings("unchecked") + public CoGbkResult( + CoGbkResultSchema schema, + Iterable values) { + this.schema = schema; + valueMap = new TreeMap<>(); + for (RawUnionValue value : values) { + // Make sure the given union tag has a corresponding tuple tag in the + // schema. + int unionTag = value.getUnionTag(); + if (schema.size() <= unionTag) { + throw new IllegalStateException("union tag " + unionTag + + " has no corresponding tuple tag in the result schema"); + } + List taggedValueList = valueMap.get(unionTag); + if (taggedValueList == null) { + taggedValueList = new ArrayList<>(); + valueMap.put(unionTag, taggedValueList); + } + taggedValueList.add(value); + } + } + + public boolean isEmpty() { + return valueMap == null || valueMap.isEmpty(); + } + + /** + * Returns the schema used by this CoGbkResult. + */ + public CoGbkResultSchema getSchema() { + return schema; + } + + @Override + public String toString() { + return valueMap.toString(); + } + + /** + * Returns the values from the table represented by the given + * {@code TupleTag} as an {@code Iterable} (which may be empty if there + * are no results). + */ + public Iterable getAll(TupleTag tag) { + int index = schema.getIndex(tag); + if (index < 0) { + throw new IllegalArgumentException("TupleTag " + tag + + " is not in the schema"); + } + List unions = valueMap.get(index); + if (unions == null) { + return buildEmptyIterable(tag); + } + return new UnionValueIterable<>(unions); + } + + /** + * If there is a singleton value for the given tag, returns it. + * Otherwise, throws an IllegalArgumentException. + */ + public V getOnly(TupleTag tag) { + return innerGetOnly(tag, null, false); + } + + /** + * If there is a singleton value for the given tag, returns it. If there is + * no value for the given tag, returns the defaultValue. + * Otherwise, throws an IllegalArgumentException. + */ + public V getOnly(TupleTag tag, V defaultValue) { + return innerGetOnly(tag, defaultValue, true); + } + + /** + * A coder for CoGbkResults. + */ + public static class CoGbkResultCoder extends StandardCoder { + + private final CoGbkResultSchema schema; + private final MapCoder> mapCoder; + + /** + * Returns a CoGbkResultCoder for the given schema and unionCoder. + */ + public static CoGbkResultCoder of( + CoGbkResultSchema schema, + UnionCoder unionCoder) { + return new CoGbkResultCoder(schema, unionCoder); + } + + @JsonCreator + public static CoGbkResultCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components, + @JsonProperty(PropertyNames.CO_GBK_RESULT_SCHEMA) CoGbkResultSchema schema) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return new CoGbkResultCoder(schema, (MapCoder) components.get(0)); + } + + private CoGbkResultCoder( + CoGbkResultSchema tupleTags, + UnionCoder unionCoder) { + this.schema = tupleTags; + this.mapCoder = MapCoder.of(VarIntCoder.of(), + ListCoder.of(unionCoder)); + } + + private CoGbkResultCoder( + CoGbkResultSchema tupleTags, + MapCoder mapCoder) { + this.schema = tupleTags; + this.mapCoder = mapCoder; + } + + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return Arrays.>asList(mapCoder); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addObject(result, PropertyNames.CO_GBK_RESULT_SCHEMA, schema.asCloudObject()); + return result; + } + + @Override + public void encode( + CoGbkResult value, + OutputStream outStream, + Context context) throws CoderException, + IOException { + if (!schema.equals(value.getSchema())) { + throw new CoderException("input schema does not match coder schema"); + } + mapCoder.encode(value.valueMap, outStream, context); + } + + @Override + public CoGbkResult decode( + InputStream inStream, + Context context) + throws CoderException, IOException { + Map> map = mapCoder.decode( + inStream, context); + return new CoGbkResult(schema, map); + } + + public boolean equals(Object other) { + if (!super.equals(other)) { + return false; + } + return schema.equals(((CoGbkResultCoder) other).schema); + } + + @Override + public boolean isDeterministic() { + return mapCoder.isDeterministic(); + } + } + + + ////////////////////////////////////////////////////////////////////////////// + // Methods for testing purposes + + /** + * Returns a new CoGbkResult that contains just the given tag the given data. + */ + public static CoGbkResult of(TupleTag tag, List data) { + return CoGbkResult.empty().and(tag, data); + } + + /** + * Returns a new CoGbkResult based on this, with the given tag and given data + * added to it. + */ + public CoGbkResult and(TupleTag tag, List data) { + if (nextTestUnionId != schema.size()) { + throw new IllegalArgumentException( + "Attempting to call and() on a CoGbkResult apparently not created by" + + " of()."); + } + Map> valueMap = new TreeMap<>(this.valueMap); + valueMap.put(nextTestUnionId, + convertValueListToUnionList(nextTestUnionId, data)); + return new CoGbkResult( + new CoGbkResultSchema(schema.getTupleTagList().and(tag)), valueMap, + nextTestUnionId + 1); + } + + /** + * Returns an empty CoGbkResult. + */ + public static CoGbkResult empty() { + return new CoGbkResult(new CoGbkResultSchema(TupleTagList.empty()), + new TreeMap>()); + } + + ////////////////////////////////////////////////////////////////////////////// + + private int nextTestUnionId = 0; + + private CoGbkResult( + CoGbkResultSchema schema, + Map> valueMap, + int nextTestUnionId) { + this(schema, valueMap); + this.nextTestUnionId = nextTestUnionId; + } + + private CoGbkResult( + CoGbkResultSchema schema, + Map> valueMap) { + this.schema = schema; + this.valueMap = valueMap; + } + + private static List convertValueListToUnionList( + int unionTag, List data) { + List unionList = new ArrayList<>(); + for (V value : data) { + unionList.add(new RawUnionValue(unionTag, value)); + } + return unionList; + } + + private Iterable buildEmptyIterable(TupleTag tag) { + return new ArrayList<>(); + } + + private V innerGetOnly( + TupleTag tag, + V defaultValue, + boolean useDefault) { + int index = schema.getIndex(tag); + if (index < 0) { + throw new IllegalArgumentException("TupleTag " + tag + + " is not in the schema"); + } + List unions = valueMap.get(index); + if (unions.isEmpty()) { + if (useDefault) { + return defaultValue; + } else { + throw new IllegalArgumentException("TupleTag " + tag + + " corresponds to an empty result, and no default was provided"); + } + } + if (unions.size() != 1) { + throw new IllegalArgumentException("TupleTag " + tag + + " corresponds to a non-singleton result of size " + unions.size()); + } + return (V) unions.get(0).getValue(); + } + + /** + * Lazily converts and recasts an {@code Iterable} into an + * {@code Iterable}, where V is the type of the raw union value's contents. + */ + private static class UnionValueIterable implements Iterable { + + private final Iterable unions; + + private UnionValueIterable(Iterable unions) { + this.unions = unions; + } + + @Override + public Iterator iterator() { + final Iterator unionsIterator = unions.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + return unionsIterator.hasNext(); + } + + @Override + public V next() { + return (V) unionsIterator.next().getValue(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultSchema.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultSchema.java new file mode 100644 index 000000000000..93883b80750c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultSchema.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static com.google.cloud.dataflow.sdk.util.Structs.addList; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +/** + * A schema for the results of a CoGroupByKey. This maintains the full + * set of TupleTags for the results of a CoGroupByKey, and facilitates mapping + * between TupleTags and Union Tags (which are used as secondary keys in the + * CoGroupByKey). + */ +class CoGbkResultSchema implements Serializable { + + private final TupleTagList tupleTagList; + + @JsonCreator + public static CoGbkResultSchema of( + @JsonProperty(PropertyNames.TUPLE_TAGS) List> tags) { + TupleTagList tupleTags = TupleTagList.empty(); + for (TupleTag tag : tags) { + tupleTags = tupleTags.and(tag); + } + return new CoGbkResultSchema(tupleTags); + } + + /** + * Maps TupleTags to union tags. This avoids needing to encode the tags + * themselves. + */ + private final HashMap, Integer> tagMap = new HashMap<>(); + + /** + * Builds a schema from a tuple of {@code TupleTag}s. + */ + public CoGbkResultSchema(TupleTagList tupleTagList) { + this.tupleTagList = tupleTagList; + int index = -1; + for (TupleTag tag : tupleTagList.getAll()) { + index++; + tagMap.put(tag, index); + } + } + + /** + * Returns the index for the given tuple tag, if the tag is present in this + * schema, -1 if it isn't. + */ + public int getIndex(TupleTag tag) { + Integer index = tagMap.get(tag); + return index == null ? -1 : index; + } + + /** + * Returns the JoinTupleTag at the given index. + */ + public TupleTag getTag(int index) { + return tupleTagList.get(index); + } + + /** + * Returns the number of columms for this schema. + */ + public int size() { + return tupleTagList.getAll().size(); + } + + /** + * Returns the TupleTagList tuple associated with this schema. + */ + public TupleTagList getTupleTagList() { + return tupleTagList; + } + + public CloudObject asCloudObject() { + CloudObject result = CloudObject.forClass(getClass()); + List serializedTags = new ArrayList<>(tupleTagList.size()); + for (TupleTag tag : tupleTagList.getAll()) { + serializedTags.add(tag.asCloudObject()); + } + addList(result, PropertyNames.TUPLE_TAGS, serializedTags); + return result; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof CoGbkResultSchema)) { + return false; + } + CoGbkResultSchema other = (CoGbkResultSchema) obj; + return tupleTagList.getAll().equals(other.tupleTagList.getAll()); + } + + @Override + public int hashCode() { + return tupleTagList.getAll().hashCode(); + } + + @Override + public String toString() { + return "CoGbkResultSchema: " + tupleTagList.getAll(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKey.java new file mode 100644 index 000000000000..d81c9ef707ca --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKey.java @@ -0,0 +1,208 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult.CoGbkResultCoder; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple.TaggedKeyedPCollection; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import java.util.ArrayList; +import java.util.List; + +/** + * A transform that performs a CoGroupByKey on a tuple of tables. A + * CoGroupByKey groups results from all tables by like keys into CoGbkResults, + * from which the results for any specific table can be accessed by the + * TupleTag supplied with the initial table. + * + *

Example of performing a CoGroupByKey followed by a ParDo that consumes + * the results: + *

 
+ * {@literal PCollection>} pt1 = ...;
+ * {@literal PCollection>} pt2 = ...;
+ *
+ * final {@literal TupleTag} t1 = new {@literal TupleTag<>()};
+ * final {@literal TupleTag} t2 = new {@literal TupleTag<>()};
+ * {@literal PCollection>} coGbkResultCollection =
+ *   KeyedPCollectionTuple.of(t1, pt1)
+ *                        .and(t2, pt2)
+ *                        .apply({@literal CoGroupByKey.create()});
+ *
+ * {@literal PCollection} finalResultCollection =
+ *   coGbkResultCollection.apply(ParDo.of(
+ *     new {@literal DoFn, T>()} {
+ *       {@literal @}Override
+ *       public void processElement(ProcessContext c) {
+ *         {@literal KV} e = c.element();
+ *         {@literal Iterable} pt1Vals = e.getValue().getAll(t1);
+ *         V2 pt2Val = e.getValue().getOnly(t2);
+ *          ... Do Something ....
+ *         c.output(...some T...);
+ *       }
+ *     }));
+ *  
+ * + * @param the type of the keys in the input and output + * {@code PCollection}s + */ +public class CoGroupByKey extends + PTransform, + PCollection>> { + /** + * Returns a {@code CoGroupByKey} {@code PTransform}. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + */ + public static CoGroupByKey create() { + return new CoGroupByKey<>(); + } + + private CoGroupByKey() { } + + @Override + public PCollection> apply( + KeyedPCollectionTuple input) { + if (input.isEmpty()) { + throw new IllegalArgumentException( + "must have at least one input to a KeyedPCollections"); + } + + // First build the union coder. + // TODO: Look at better integration of union types with the + // schema specified in the input. + List> codersList = new ArrayList<>(); + for (TaggedKeyedPCollection entry : input.getKeyedCollections()) { + codersList.add(getValueCoder(entry.pCollection)); + } + UnionCoder unionCoder = UnionCoder.of(codersList); + Coder keyCoder = input.getKeyCoder(); + KvCoder kVCoder = + KvCoder.of(keyCoder, unionCoder); + + PCollectionList> unionTables = + PCollectionList.empty(getPipeline()); + + // TODO: Use the schema to order the indices rather than depending + // on the fact that the schema ordering is identical to the ordering from + // input.getJoinCollections(). + int index = -1; + for (TaggedKeyedPCollection entry : input.getKeyedCollections()) { + index++; + PCollection> unionTable = + makeUnionTable(index, entry.pCollection, kVCoder); + unionTables = unionTables.and(unionTable); + } + + PCollection> flattenedTable = + unionTables.apply(Flatten.>create()); + + PCollection>> groupedTable = + flattenedTable.apply(GroupByKey.create()); + + CoGbkResultSchema tupleTags = input.getCoGbkResultSchema(); + PCollection> result = groupedTable.apply( + ParDo.of(new ConstructCoGbkResultFn(tupleTags)) + .named("ConstructCoGbkResultFn")); + result.setCoder(KvCoder.of(keyCoder, + CoGbkResultCoder.of(tupleTags, unionCoder))); + + return result; + } + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Returns the value coder for the given PCollection. Assumes that the value + * coder is an instance of {@code KvCoder}. + */ + private Coder getValueCoder(PCollection> pCollection) { + // Assumes that the PCollection uses a KvCoder. + Coder entryCoder = pCollection.getCoder(); + if (!(entryCoder instanceof KvCoder)) { + throw new IllegalArgumentException("PCollection does not use a KvCoder"); + } + @SuppressWarnings("unchecked") + KvCoder coder = (KvCoder) entryCoder; + return coder.getValueCoder(); + } + + /** + * Returns a UnionTable for the given input PCollection, using the given + * union index and the given unionTableEncoder. + */ + private PCollection> makeUnionTable( + final int index, + PCollection> pCollection, + KvCoder unionTableEncoder) { + + return pCollection.apply(ParDo.of( + new ConstructUnionTableFn(index)).named("MakeUnionTable")) + .setCoder(unionTableEncoder); + } + + /** + * A DoFn to construct a UnionTable (i.e., a + * {@code PCollection>} from a + * {@code PCollection>}. + */ + private static class ConstructUnionTableFn extends + DoFn, KV> { + + private final int index; + + public ConstructUnionTableFn(int index) { + this.index = index; + } + + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + c.output(KV.of(e.getKey(), new RawUnionValue(index, e.getValue()))); + } + } + + /** + * A DoFn to construct a CoGbkResult from an input grouped union + * table. + */ + private static class ConstructCoGbkResultFn + extends DoFn>, + KV> { + + private final CoGbkResultSchema schema; + + public ConstructCoGbkResultFn(CoGbkResultSchema schema) { + this.schema = schema; + } + + @Override + public void processElement(ProcessContext c) { + KV> e = c.element(); + c.output(KV.of(e.getKey(), new CoGbkResult(schema, e.getValue()))); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/KeyedPCollectionTuple.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/KeyedPCollectionTuple.java new file mode 100644 index 000000000000..a9fd4b684f85 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/KeyedPCollectionTuple.java @@ -0,0 +1,217 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Represents an immutable tuple of keyed PCollections (i.e. PCollections of + * {@code KV}), with key type K. + * + * @param the type of key shared by all constituent PCollections + */ +public class KeyedPCollectionTuple implements PInput { + /** + * Returns an empty {@code KeyedPCollections} on the given pipeline. + */ + public static KeyedPCollectionTuple empty(Pipeline pipeline) { + return new KeyedPCollectionTuple<>(pipeline); + } + + /** + * Returns a new {@code KeyedPCollections} with the given tag and initial + * PCollection. + */ + public static KeyedPCollectionTuple of( + TupleTag tag, + PCollection> pc) { + return new KeyedPCollectionTuple(pc.getPipeline()).and(tag, pc); + } + + /** + * Returns a new {@code KeyedPCollections} that is the same as this, + * appended with the given PCollection. + */ + public KeyedPCollectionTuple and( + TupleTag< V> tag, + PCollection> pc) { + if (pc.getPipeline() != getPipeline()) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + TaggedKeyedPCollection wrapper = + new TaggedKeyedPCollection<>(tag, pc); + Coder myKeyCoder = keyCoder == null ? getKeyCoder(pc) : keyCoder; + List> + newKeyedCollections = + copyAddLast( + keyedCollections, + wrapper); + return new KeyedPCollectionTuple<>( + getPipeline(), + newKeyedCollections, + schema.getTupleTagList().and(tag), + myKeyCoder); + } + + public boolean isEmpty() { + return keyedCollections.isEmpty(); + } + + /** + * Returns a list of TaggedKeyedPCollections for the PCollections contained in + * this {@code KeyedPCollections}. + */ + public List> getKeyedCollections() { + return keyedCollections; + } + + /** + * Applies the given transform to this input. + */ + public O apply( + PTransform, O> transform) { + return Pipeline.applyTransform(this, transform); + } + + /** + * Expands the component PCollections, stripping off any tag-specific + * information. + */ + @Override + public Collection expand() { + List> retval = new ArrayList<>(); + for (TaggedKeyedPCollection taggedPCollection : keyedCollections) { + retval.add(taggedPCollection.pCollection); + } + return retval; + } + + /** + * Returns the KeyCoder for all PCollections in this KeyedPCollections. + */ + public Coder getKeyCoder() { + if (keyCoder == null) { + throw new IllegalStateException("cannot return null keyCoder"); + } + return keyCoder; + } + + /** + * Returns the CoGbkResultSchema associated with this + * KeyedPCollections. + */ + public CoGbkResultSchema getCoGbkResultSchema() { + return schema; + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public void finishSpecifying() { + for (TaggedKeyedPCollection taggedPCollection : keyedCollections) { + taggedPCollection.pCollection.finishSpecifying(); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A utility class to help ensure coherence of tag and input PCollection + * types. + */ + static class TaggedKeyedPCollection { + final TupleTag tupleTag; + final PCollection> pCollection; + + public TaggedKeyedPCollection( + TupleTag tupleTag, + PCollection> pCollection) { + this.tupleTag = tupleTag; + this.pCollection = pCollection; + } + } + + /** + * We use a List to properly track the order in which collections are added. + */ + private final List> keyedCollections; + + private final Coder keyCoder; + + private final CoGbkResultSchema schema; + + private final Pipeline pipeline; + + KeyedPCollectionTuple(Pipeline pipeline) { + this(pipeline, + new ArrayList>(), + TupleTagList.empty(), + null); + } + + KeyedPCollectionTuple( + Pipeline pipeline, + List> keyedCollections, + TupleTagList tupleTagList, + Coder keyCoder) { + this.pipeline = pipeline; + this.keyedCollections = keyedCollections; + this.schema = new CoGbkResultSchema(tupleTagList); + this.keyCoder = keyCoder; + } + + private static Coder getKeyCoder(PCollection> pc) { + // Need to run coder inference on this PCollection before inspecting it. + pc.finishSpecifying(); + + // Assumes that the PCollection uses a KvCoder. + Coder entryCoder = pc.getCoder(); + if (!(entryCoder instanceof KvCoder)) { + throw new IllegalArgumentException("PCollection does not use a KvCoder"); + } + @SuppressWarnings("unchecked") + KvCoder coder = (KvCoder) entryCoder; + return coder.getKeyCoder(); + } + + private static List> copyAddLast( + List> keyedCollections, + TaggedKeyedPCollection taggedCollection) { + List> retval = + new ArrayList<>(keyedCollections); + retval.add(taggedCollection); + return retval; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/RawUnionValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/RawUnionValue.java new file mode 100644 index 000000000000..b52f8b3e49c2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/RawUnionValue.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +// TODO: Think about making this a complete dynamic union by adding +// a schema. Type would then be defined by the corresponding schema entry. + +/** + * This corresponds to an integer union tag and value. The mapping of + * union tag to type must come from elsewhere. + */ +class RawUnionValue { + private final int unionTag; + private final Object value; + + /** + * Constructs a partial union from the given union tag and value. + */ + public RawUnionValue(int unionTag, Object value) { + this.unionTag = unionTag; + this.value = value; + } + + public int getUnionTag() { + return unionTag; + } + + public Object getValue() { + return value; + } + + @Override + public String toString() { + return unionTag + ":" + value; + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoder.java new file mode 100644 index 000000000000..a6bb4bcb4586 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoder.java @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * A UnionCoder encodes RawUnionValues. + */ +class UnionCoder extends StandardCoder { + // TODO: Think about how to integrate this with a schema object (i.e. + // a tuple of tuple tags). + /** + * Builds a union coder with the given list of element coders. This list + * corresponds to a mapping of union tag to Coder. Union tags start at 0. + */ + public static UnionCoder of(List> elementCoders) { + return new UnionCoder(elementCoders); + } + + @JsonCreator + public static UnionCoder jsonOf( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> elements) { + return UnionCoder.of(elements); + } + + private int getIndexForEncoding(RawUnionValue union) { + if (union == null) { + throw new IllegalArgumentException("cannot encode a null tagged union"); + } + int index = union.getUnionTag(); + if (index < 0 || index >= elementCoders.size()) { + throw new IllegalArgumentException( + "union value index " + index + " not in range [0.." + + (elementCoders.size() - 1) + "]"); + } + return index; + } + + @SuppressWarnings("unchecked") + @Override + public void encode( + RawUnionValue union, + OutputStream outStream, + Context context) + throws IOException, CoderException { + int index = getIndexForEncoding(union); + // Write out the union tag. + VarInt.encode(index, outStream); + + // Write out the actual value. + Coder coder = (Coder) elementCoders.get(index); + coder.encode( + union.getValue(), + outStream, + context); + } + + @Override + public RawUnionValue decode(InputStream inStream, Context context) + throws IOException, CoderException { + int index = VarInt.decodeInt(inStream); + Object value = elementCoders.get(index).decode(inStream, context); + return new RawUnionValue(index, value); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return elementCoders; + } + + /** + * Since this coder uses elementCoders.get(index) and coders that are known to run in constant + * time, we defer the return value to that coder. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(RawUnionValue union, Context context) { + int index = getIndexForEncoding(union); + Coder coder = (Coder) elementCoders.get(index); + return coder.isRegisterByteSizeObserverCheap(union.getValue(), context); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + RawUnionValue union, ElementByteSizeObserver observer, Context context) + throws Exception { + int index = getIndexForEncoding(union); + // Write out the union tag. + observer.update(VarInt.getLength(index)); + // Write out the actual value. + Coder coder = (Coder) elementCoders.get(index); + coder.registerByteSizeObserver(union.getValue(), observer, context); + } + + ///////////////////////////////////////////////////////////////////////////// + + private final List> elementCoders; + + private UnionCoder(List> elementCoders) { + this.elementCoders = elementCoders; + } + + @Override + public boolean isDeterministic() { + for (Coder elementCoder : elementCoders) { + if (!elementCoder.isDeterministic()) { + return false; + } + } + + return true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/package-info.java new file mode 100644 index 000000000000..ba907ac2cd73 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines the {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey} transform + * for joining multiple PCollections. + */ +package com.google.cloud.dataflow.sdk.transforms.join; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/package-info.java new file mode 100644 index 000000000000..b72e90e780ac --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/package-info.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s for transforming + * data in a pipeline. + * + *

A {@link com.google.cloud.dataflow.sdk.transforms.PTransform} is an operation that takes an + * {@code Input} (some subtype of {@link com.google.cloud.dataflow.sdk.values.PInput}) + * and produces an + * {@code Output} (some subtype of {@link com.google.cloud.dataflow.sdk.values.POutput}). + * + *

Common PTransforms include root PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read} and + * {@link com.google.cloud.dataflow.sdk.transforms.Create}, processing and + * conversion operations like {@link com.google.cloud.dataflow.sdk.transforms.ParDo}, + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}, + * {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey}, + * {@link com.google.cloud.dataflow.sdk.transforms.Combine}, and + * {@link com.google.cloud.dataflow.sdk.transforms.Count}, and outputting + * PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Write}. + * + *

New PTransforms can be created by composing existing PTransforms. + * Most PTransforms in this package are composites, and users can also create composite PTransforms + * for their own application-specific logic. + * + */ +package com.google.cloud.dataflow.sdk.transforms; + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/BoundedWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/BoundedWindow.java new file mode 100644 index 000000000000..01de83f1585d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/BoundedWindow.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import org.joda.time.Instant; + +/** + * A {@code BoundedWindow} represents a finite grouping of elements, with an + * upper bound (larger timestamps represent more recent data) on the timestamps + * of elements that can be placed in the window. This finiteness means that for + * every window, at some point in time, all data for that window will have + * arrived and can be processed together. + * + *

Windows must also implement {@link Object#equals} and + * {@link Object#hashCode} such that windows that are logically equal will + * be treated as equal by {@code equals()} and {@code hashCode()}. + */ +public abstract class BoundedWindow { + /** + * Returns the upper bound of timestamps for values in this window. + */ + public abstract Instant maxTimestamp(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindows.java new file mode 100644 index 000000000000..bb0de796f86a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindows.java @@ -0,0 +1,300 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.Days; +import org.joda.time.Instant; +import org.joda.time.Months; +import org.joda.time.Years; + +/** + * A collection of {@link WindowingFn}s that windows values into calendar-based + * windows such as spans of days, months, or years. + * + *

For example, to group data into quarters that change on the 15th, use + * {@code CalendarWindows.months(3).withStartingMonth(2014, 1).beginningOnDay(15)}. + */ +public class CalendarWindows { + + /** + * Returns a {@link WindowingFn} that windows elements into periods measured by days. + * + *

For example, {@code CalendarWindows.days(1)} will window elements into + * separate windows for each day. + */ + public static DaysWindows days(int number) { + return new DaysWindows(number, new DateTime(0, DateTimeZone.UTC), DateTimeZone.UTC); + } + + /** + * Returns a {@link WindowingFn} that windows elements into periods measured by weeks. + * + *

For example, {@code CalendarWindows.weeks(1, DateTimeConstants.TUESDAY)} will + * window elements into week-long windows starting on Tuesdays. + */ + public static DaysWindows weeks(int number, int startDayOfWeek) { + return new DaysWindows( + 7 * number, + new DateTime(0, DateTimeZone.UTC).withDayOfWeek(startDayOfWeek), + DateTimeZone.UTC); + } + + /** + * Returns a {@link WindowingFn} that windows elements into periods measured by months. + * + *

For example, + * {@code CalendarWindows.months(8).withStartingMonth(2014, 1).beginningOnDay(10)} + * will window elements into 8 month windows where that start on the 10th day of month, + * and the first window begins in January 2014. + */ + public static MonthsWindows months(int number) { + return new MonthsWindows(number, 1, new DateTime(0, DateTimeZone.UTC), DateTimeZone.UTC); + } + + /** + * Returns a {@link WindowingFn} that windows elements into periods measured by years. + * + *

For example, + * {@code CalendarWindows.years(1).withTimeZone(DateTimeZone.forId("America/Los_Angeles"))} + * will window elements into year-long windows that start at midnight on Jan 1, in the + * America/Los_Angeles time zone. + */ + public static YearsWindows years(int number) { + return new YearsWindows(number, 1, 1, new DateTime(0, DateTimeZone.UTC), DateTimeZone.UTC); + } + + /** + * A {@link WindowingFn} that windows elements into periods measured by days. + * + *

By default, periods of multiple days are measured starting at the + * epoch. This can be overridden with {@link #withStartingDay}. + * + *

The time zone used to determine calendar boundaries is UTC, unless this + * is overridden with the {@link #withTimeZone} method. + */ + public static class DaysWindows extends PartitioningWindowingFn { + + public DaysWindows withStartingDay(int year, int month, int day) { + return new DaysWindows( + number, new DateTime(year, month, day, 0, 0, timeZone), timeZone); + } + + public DaysWindows withTimeZone(DateTimeZone timeZone) { + return new DaysWindows( + number, startDate.withZoneRetainFields(timeZone), timeZone); + } + + //////////////////////////////////////////////////////////////////////////// + + private int number; + private DateTime startDate; + private DateTimeZone timeZone; + + private DaysWindows(int number, DateTime startDate, DateTimeZone timeZone) { + this.number = number; + this.startDate = startDate; + this.timeZone = timeZone; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + DateTime datetime = new DateTime(timestamp, timeZone); + + int dayOffset = Days.daysBetween(startDate, datetime).getDays() / number * number; + + DateTime begin = startDate.plusDays(dayOffset); + DateTime end = begin.plusDays(number); + + return new IntervalWindow(begin.toInstant(), end.toInstant()); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowingFn other) { + if (!(other instanceof DaysWindows)) { + return false; + } + DaysWindows that = (DaysWindows) other; + return number == that.number + && startDate == that.startDate + && timeZone == that.timeZone; + } + } + + /** + * A {@link WindowingFn} that windows elements into periods measured by months. + * + *

By default, periods of multiple months are measured starting at the + * epoch. This can be overridden with {@link #withStartingMonth}. + * + *

Months start on the first day of each calendar month, unless overridden by + * {@link #beginningOnDay}. + * + *

The time zone used to determine calendar boundaries is UTC, unless this + * is overridden with the {@link #withTimeZone} method. + */ + public static class MonthsWindows extends PartitioningWindowingFn { + + public MonthsWindows beginningOnDay(int dayOfMonth) { + return new MonthsWindows( + number, dayOfMonth, startDate, timeZone); + } + + public MonthsWindows withStartingMonth(int year, int month) { + return new MonthsWindows( + number, dayOfMonth, new DateTime(year, month, 1, 0, 0, timeZone), timeZone); + } + + public MonthsWindows withTimeZone(DateTimeZone timeZone) { + return new MonthsWindows( + number, dayOfMonth, startDate.withZoneRetainFields(timeZone), timeZone); + } + + //////////////////////////////////////////////////////////////////////////// + + private int number; + private int dayOfMonth; + private DateTime startDate; + private DateTimeZone timeZone; + + private MonthsWindows(int number, int dayOfMonth, DateTime startDate, DateTimeZone timeZone) { + this.number = number; + this.dayOfMonth = dayOfMonth; + this.startDate = startDate; + this.timeZone = timeZone; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + DateTime datetime = new DateTime(timestamp, timeZone); + + int monthOffset = + Months.monthsBetween(startDate.withDayOfMonth(dayOfMonth), datetime).getMonths() + / number * number; + + DateTime begin = startDate.withDayOfMonth(dayOfMonth).plusMonths(monthOffset); + DateTime end = begin.plusMonths(number); + + return new IntervalWindow(begin.toInstant(), end.toInstant()); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowingFn other) { + if (!(other instanceof MonthsWindows)) { + return false; + } + MonthsWindows that = (MonthsWindows) other; + return number == that.number + && dayOfMonth == dayOfMonth + && startDate == that.startDate + && timeZone == that.timeZone; + } + } + + /** + * A {@link WindowingFn} that windows elements into periods measured by years. + * + *

By default, periods of multiple years are measured starting at the + * epoch. This can be overridden with {@link #withStartingYear}. + * + *

Years start on the first day of each calendar year, unless overridden by + * {@link #beginningOnDay}. + * + *

The time zone used to determine calendar boundaries is UTC, unless this + * is overridden with the {@link #withTimeZone} method. + */ + public static class YearsWindows extends PartitioningWindowingFn { + + public YearsWindows beginningOnDay(int monthOfYear, int dayOfMonth) { + return new YearsWindows( + number, monthOfYear, dayOfMonth, startDate, timeZone); + } + + public YearsWindows withStartingYear(int year) { + return new YearsWindows( + number, monthOfYear, dayOfMonth, new DateTime(year, 1, 1, 0, 0, timeZone), timeZone); + } + + public YearsWindows withTimeZone(DateTimeZone timeZone) { + return new YearsWindows( + number, monthOfYear, dayOfMonth, startDate.withZoneRetainFields(timeZone), timeZone); + } + + //////////////////////////////////////////////////////////////////////////// + + private int number; + private int monthOfYear; + private int dayOfMonth; + private DateTime startDate; + private DateTimeZone timeZone; + + private YearsWindows( + int number, int monthOfYear, int dayOfMonth, DateTime startDate, DateTimeZone timeZone) { + this.number = number; + this.monthOfYear = monthOfYear; + this.dayOfMonth = dayOfMonth; + this.startDate = startDate; + this.timeZone = timeZone; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + DateTime datetime = new DateTime(timestamp, timeZone); + + DateTime offsetStart = startDate.withMonthOfYear(monthOfYear).withDayOfMonth(dayOfMonth); + + int yearOffset = + Years.yearsBetween(offsetStart, datetime).getYears() / number * number; + + DateTime begin = offsetStart.plusYears(yearOffset); + DateTime end = begin.plusYears(number); + + return new IntervalWindow(begin.toInstant(), end.toInstant()); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowingFn other) { + if (!(other instanceof YearsWindows)) { + return false; + } + YearsWindows that = (YearsWindows) other; + return number == that.number + && monthOfYear == monthOfYear + && dayOfMonth == dayOfMonth + && startDate == that.startDate + && timeZone == that.timeZone; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindows.java new file mode 100644 index 000000000000..ea7a22c8fc41 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindows.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** + * A {@link WindowingFn} that windows values into fixed-size timestamp-based windows. + * + *

For example, in order to partition the data into 10 minute windows: + *

 {@code
+ * PCollection items = ...;
+ * PCollection windowedItems = items.apply(
+ *   Window.by(FixedWindows.of(Duration.standardMinutes(10))));
+ * } 
+ */ +public class FixedWindows extends PartitioningWindowingFn { + + /** + * Size of this window. + */ + private final Duration size; + + /** + * Offset of this window. Windows start at time + * N * size + offset, where 0 is the epoch. + */ + private final Duration offset; + + /** + * Partitions the timestamp space into half-open intervals of the form + * [N * size, (N + 1) * size), where 0 is the epoch. + */ + public static FixedWindows of(Duration size) { + return new FixedWindows(size, Duration.ZERO); + } + + /** + * Partitions the timestamp space into half-open intervals of the form + * [N * size + offset, (N + 1) * size + offset), + * where 0 is the epoch. + * + * @throws IllegalAgumentException if offset is not in [0, size) + */ + public FixedWindows withOffset(Duration offset) { + return new FixedWindows(size, offset); + } + + private FixedWindows(Duration size, Duration offset) { + if (offset.isShorterThan(Duration.ZERO) || !offset.isShorterThan(size)) { + throw new IllegalArgumentException( + "FixedWindows WindowingStrategies must have 0 <= offset < size"); + } + this.size = size; + this.offset = offset; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + long start = timestamp.getMillis() + - timestamp.plus(size).minus(offset).getMillis() % size.getMillis(); + return new IntervalWindow(new Instant(start), size); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getFixedSizeCoder(size); + } + + @Override + public boolean isCompatible(WindowingFn other) { + return (other instanceof FixedWindows) + && (size.equals(((FixedWindows) other).size)) + && (offset.equals(((FixedWindows) other).offset)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindow.java new file mode 100644 index 000000000000..bfcb9c7fa159 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindow.java @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Instant; + +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Collection; + +/** + * Default {@link WindowingFn} where all data is in the same bucket. + */ +public class GlobalWindow + extends NonMergingWindowingFn { + @Override + public Collection assignWindows(AssignContext c) { + return Arrays.asList(Window.INSTANCE); + } + + @Override + public boolean isCompatible(WindowingFn o) { + return o instanceof GlobalWindow; + } + + @Override + public Coder windowCoder() { + return Window.Coder.INSTANCE; + } + + /** + * The default window into which all data is placed. + */ + public static class Window extends BoundedWindow { + public static Window INSTANCE = new Window(); + + @Override + public Instant maxTimestamp() { + return new Instant(Long.MAX_VALUE); + } + + private Window() {} + + /** + * {@link Coder} for encoding and decoding {@code Window}s. + */ + public static class Coder extends AtomicCoder { + public static Coder INSTANCE = new Coder(); + + @Override + public void encode(Window window, OutputStream outStream, Context context) {} + + @Override + public Window decode(InputStream inStream, Context context) { + return Window.INSTANCE; + } + + @Override + public boolean isDeterministic() { + return true; + } + + private Coder() {} + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindow.java new file mode 100644 index 000000000000..8ac23501c97e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindow.java @@ -0,0 +1,257 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * An implementation of {@link BoundedWindow} that represents an interval from + * {@link #start} (inclusive) to {@link #end} (exclusive). + */ +public class IntervalWindow extends BoundedWindow + implements Comparable { + /** + * Start of the interval, inclusive. + */ + private final Instant start; + + /** + * End of the interval, exclusive. + */ + private final Instant end; + + /** + * Creates a new IntervalWindow that represents the half-open time + * interval [start, end). + */ + public IntervalWindow(Instant start, Instant end) { + this.start = start; + this.end = end; + } + + public IntervalWindow(Instant start, Duration size) { + this.start = start; + this.end = start.plus(size); + } + + /** + * Returns the start of this window, inclusive. + */ + public Instant start() { + return start; + } + + /** + * Returns the end of this window, exclusive. + */ + public Instant end() { + return end; + } + + /** + * Returns the largest timestamp that can be included in this window. + */ + @Override + public Instant maxTimestamp() { + // end not inclusive + return end.minus(1); + } + + /** + * Returns whether this window contains the given window. + */ + public boolean contains(IntervalWindow other) { + return !this.start.isAfter(other.start) && !this.end.isBefore(other.end); + } + + /** + * Returns whether this window is disjoint from the given window. + */ + public boolean isDisjoint(IntervalWindow other) { + return !this.end.isAfter(other.start) || !other.end.isAfter(this.start); + } + + /** + * Returns whether this window intersects the given window. + */ + public boolean intersects(IntervalWindow other) { + return !isDisjoint(other); + } + + /** + * Returns the minimal window that includes both this window and + * the given window. + */ + public IntervalWindow span(IntervalWindow other) { + return new IntervalWindow( + new Instant(Math.min(start.getMillis(), other.start.getMillis())), + new Instant(Math.max(end.getMillis(), other.end.getMillis()))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof IntervalWindow) + && ((IntervalWindow) o).end.isEqual(end) + && ((IntervalWindow) o).start.isEqual(start); + } + + @Override + public int hashCode() { + // The end values are themselves likely to be arithmetic sequence, + // which is a poor distribution to use for a hashtable, so we + // add a highly non-linear transformation. + return (int) + (start.getMillis() + modInverse((int) (end.getMillis() << 1) + 1)); + } + + /** + * Compute the inverse of (odd) x mod 2^32. + */ + private int modInverse(int x) { + // Cube gives inverse mod 2^4, as x^4 == 1 (mod 2^4) for all odd x. + int inverse = x * x * x; + // Newton iteration doubles correct bits at each step. + inverse *= 2 - x * inverse; + inverse *= 2 - x * inverse; + inverse *= 2 - x * inverse; + return inverse; + } + + @Override + public String toString() { + return "[" + start + ".." + end + ")"; + } + + @Override + public int compareTo(IntervalWindow o) { + if (start.isEqual(o.start)) { + return end.compareTo(o.end); + } + return start.compareTo(o.start); + } + + /** + * Returns a Coder suitable for encoding IntervalWindows. + */ + public static Coder getCoder() { + return IntervalWindowCoder.of(); + } + + /** + * Returns a Coder for encoding interval windows of fixed size (which + * is more efficient than {@link #getCoder()} as it only needs to encode + * one endpoint). + */ + public static Coder getFixedSizeCoder(final Duration size) { + return FixedSizeIntervalWindowCoder.of(size); + } + + private static class IntervalWindowCoder extends AtomicCoder { + private static final IntervalWindowCoder INSTANCE = + new IntervalWindowCoder(); + private static final Coder instantCoder = InstantCoder.of(); + + @JsonCreator + public static IntervalWindowCoder of() { + return INSTANCE; + } + + @Override + public void encode(IntervalWindow window, + OutputStream outStream, + Context context) + throws IOException, CoderException { + instantCoder.encode(window.start, outStream, context.nested()); + instantCoder.encode(window.end, outStream, context.nested()); + } + + @Override + public IntervalWindow decode(InputStream inStream, Context context) + throws IOException, CoderException { + Instant start = instantCoder.decode(inStream, context.nested()); + Instant end = instantCoder.decode(inStream, context.nested()); + return new IntervalWindow(start, end); + } + + @Override + public boolean isDeterministic() { return true; } + } + + private static class FixedSizeIntervalWindowCoder + extends AtomicCoder { + private static final Coder instantCoder = InstantCoder.of(); + + private final Duration size; + + @JsonCreator + public static FixedSizeIntervalWindowCoder of( + @JsonProperty("duration") String duration) { + return of(fromCloudDuration(duration)); + } + + public static FixedSizeIntervalWindowCoder of(Duration size) { + return new FixedSizeIntervalWindowCoder(size); + } + + private FixedSizeIntervalWindowCoder(Duration size) { + this.size = size; + } + + @Override + public void encode(IntervalWindow window, + OutputStream outStream, + Context context) + throws IOException, CoderException { + instantCoder.encode(window.start, outStream, context); + } + + @Override + public IntervalWindow decode(InputStream inStream, Context context) + throws IOException, CoderException { + Instant start = instantCoder.decode(inStream, context); + return new IntervalWindow(start, size); + } + + @Override + public boolean isDeterministic() { return true; } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addString(result, "duration", toCloudDuration(Duration.millis(size.getMillis()))); + return result; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/InvalidWindowingFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/InvalidWindowingFn.java new file mode 100644 index 000000000000..7ad7f29f6655 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/InvalidWindowingFn.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import java.util.Collection; + +/** + * A {@link WindowingFn} that represents an invalid pipeline state. + * + * @param window type + */ +public class InvalidWindowingFn extends WindowingFn { + private String cause; + private WindowingFn originalWindowingFn; + + public InvalidWindowingFn(String cause, WindowingFn originalWindowingFn) { + this.originalWindowingFn = originalWindowingFn; + this.cause = cause; + } + + /** + * Returns the reason that this {@code WindowingFn} is invalid. + */ + public String getCause() { + return cause; + } + + /** + * Returns the original windowingFn that this InvalidWindowingFn replaced. + */ + public WindowingFn getOriginalWindowingFn() { + return originalWindowingFn; + } + + @Override + public Collection assignWindows(AssignContext c) { + throw new UnsupportedOperationException(); + } + + @Override + public void mergeWindows(MergeContext c) { + throw new UnsupportedOperationException(); + } + + @Override + public Coder windowCoder() { + return originalWindowingFn.windowCoder(); + } + + /** + * {@code InvalidWindowingFn} objects with the same {@code originalWindowingFn} are compatible. + */ + @Override + public boolean isCompatible(WindowingFn other) { + return getClass() == other.getClass() + && getOriginalWindowingFn().isCompatible( + ((InvalidWindowingFn) other).getOriginalWindowingFn()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java new file mode 100644 index 000000000000..4d4dd8492684 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * A {@link WindowingFn} that merges overlapping {@link IntervalWindow}s. + */ +public class MergeOverlappingIntervalWindows { + + /** + * Merge overlapping intervals. + */ + public static void mergeWindows(WindowingFn.MergeContext c) throws Exception { + // Merge any overlapping windows into a single window. + // Sort the list of existing windows so we only have to + // traverse the list once rather than considering all + // O(n^2) window pairs. + List sortedWindows = new ArrayList<>(); + for (IntervalWindow window : c.windows()) { + sortedWindows.add(window); + } + Collections.sort(sortedWindows); + List merges = new ArrayList<>(); + MergeCandidate current = new MergeCandidate(); + for (IntervalWindow window : sortedWindows) { + if (current.intersects(window)) { + current.add(window); + } else { + merges.add(current); + current = new MergeCandidate(window); + } + } + merges.add(current); + for (MergeCandidate merge : merges) { + merge.apply(c); + } + } + + private static class MergeCandidate { + private IntervalWindow union; + private final List parts; + public MergeCandidate() { + parts = new ArrayList<>(); + } + public MergeCandidate(IntervalWindow window) { + union = window; + parts = new ArrayList<>(Arrays.asList(window)); + } + public boolean intersects(IntervalWindow window) { + return union == null || union.intersects(window); + } + public void add(IntervalWindow window) { + union = union == null ? window : union.span(window); + parts.add(window); + } + public void apply(WindowingFn.MergeContext c) throws Exception { + if (parts.size() > 1) { + c.merge(parts, union); + } + } + + @Override + public String toString() { + return "MergeCandidate[union=" + union + ", parts=" + parts + "]"; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/NonMergingWindowingFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/NonMergingWindowingFn.java new file mode 100644 index 000000000000..ffeea996d60d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/NonMergingWindowingFn.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +/** + * Abstract base class for {@link WindowingFn}s that do not merge windows. + * + * @param type of elements being windowed + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code WindowingFn} + */ +public abstract class NonMergingWindowingFn + extends WindowingFn { + + @Override + public final void mergeWindows(MergeContext c) { } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PartitioningWindowingFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PartitioningWindowingFn.java new file mode 100644 index 000000000000..6a65ba134f18 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PartitioningWindowingFn.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.Collection; + +/** + * A {@link WindowingFn} that places each value into exactly one window + * based on its timestamp and never merges windows. + * + * @param type of elements being windowed + * @param window type + */ +public abstract class PartitioningWindowingFn + extends NonMergingWindowingFn { + /** + * Returns the single window to which elements with this timestamp belong. + */ + public abstract W assignWindow(Instant timestamp); + + @Override + public final Collection assignWindows(AssignContext c) { + return Arrays.asList(assignWindow(c.timestamp())); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Sessions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Sessions.java new file mode 100644 index 000000000000..47f8a0800583 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Sessions.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Duration; + +import java.util.Arrays; +import java.util.Collection; + +/** + * A WindowingFn windowing values into sessions separated by {@link #gapDuration}-long + * periods with no elements. + * + *

For example, in order to window data into session with at least 10 minute + * gaps in between them: + *

 {@code
+ * PCollection pc = ...;
+ * PCollection windowed_pc = pc.apply(
+ *   Window.by(Sessions.withGapDuration(Duration.standardMinutes(10))));
+ * } 
+ */ +public class Sessions extends WindowingFn { + + /** + * Duration of the gaps between sessions. + */ + private final Duration gapDuration; + + /** + * Creates a {@code Sessions} {@link WindowingFn} with the specified gap duration. + */ + public static Sessions withGapDuration(Duration gapDuration) { + return new Sessions(gapDuration); + } + + /** + * Creates a {@code Sessions} {@link WindowingFn} with the specified gap duration. + */ + private Sessions(Duration gapDuration) { + this.gapDuration = gapDuration; + } + + @Override + public Collection assignWindows(AssignContext c) { + // Assign each element into a window from its timestamp until gapDuration in the + // future. Overlapping windows (representing elements within gapDuration of + // each other) will be merged. + return Arrays.asList(new IntervalWindow(c.timestamp(), gapDuration)); + } + + @Override + public void mergeWindows(MergeContext c) throws Exception { + MergeOverlappingIntervalWindows.mergeWindows(c); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowingFn other) { + return other instanceof Sessions; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindows.java new file mode 100644 index 000000000000..6643289071ef --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindows.java @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * A WindowingFn that windows values into possibly overlapping fixed-size + * timestamp-based windows. + * + *

For example, in order to window data into 10 minute windows that + * update every minute: + *

 {@code
+ * PCollection items = ...;
+ * PCollection windowedItems = items.apply(
+ *   Window.by(SlidingWindows.of(Duration.standardMinutes(10))));
+ * } 
+ */ +public class SlidingWindows extends NonMergingWindowingFn { + + /** + * Amount of time between generated windows. + */ + private final Duration period; + + /** + * Size of the generated windows. + */ + private final Duration size; + + /** + * Offset of the generated windows. + * Windows start at time N * start + offset, where 0 is the epoch. + */ + private final Duration offset; + + /** + * Assigns timestamps into half-open intervals of the form + * [N * period, N * period + size), where 0 is the epoch. + * + *

If {@link SlidingWindows#every} is not called, the period defaults + * to one millisecond. + */ + public static SlidingWindows of(Duration size) { + return new SlidingWindows(new Duration(1), size, Duration.ZERO); + } + + /** + * Returns a new {@code SlidingWindows} with the original size, that assigns + * timestamps into half-open intervals of the form + * [N * period, N * period + size), where 0 is the epoch. + */ + public SlidingWindows every(Duration period) { + return new SlidingWindows(period, size, offset); + } + + /** + * Assigns timestamps into half-open intervals of the form + * [N * period + offset, N * period + offset + size). + * + * @throws IllegalArgumentException if offset is not in [0, period) + */ + public SlidingWindows withOffset(Duration offset) { + return new SlidingWindows(period, size, offset); + } + + private SlidingWindows(Duration period, Duration size, Duration offset) { + if (offset.isShorterThan(Duration.ZERO) + || !offset.isShorterThan(period) + || !size.isLongerThan(Duration.ZERO)) { + throw new IllegalArgumentException( + "SlidingWindows WindowingStrategies must have 0 <= offset < period and 0 < size"); + } + this.period = period; + this.size = size; + this.offset = offset; + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getFixedSizeCoder(size); + } + + @Override + public Collection assignWindows(AssignContext c) { + List windows = + new ArrayList<>((int) (size.getMillis() / period.getMillis())); + Instant timestamp = c.timestamp(); + long lastStart = timestamp.getMillis() + - timestamp.plus(period).minus(offset).getMillis() % period.getMillis(); + for (long start = lastStart; + start > timestamp.minus(size).getMillis(); + start -= period.getMillis()) { + windows.add(new IntervalWindow(new Instant(start), size)); + } + return windows; + } + + @Override + public boolean isCompatible(WindowingFn other) { + if (other instanceof SlidingWindows) { + SlidingWindows that = (SlidingWindows) other; + return period.equals(that.period) + && size.equals(that.size) + && offset.equals(that.offset); + } else { + return false; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Window.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Window.java new file mode 100644 index 000000000000..68796c908aba --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Window.java @@ -0,0 +1,321 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; +import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.AssignWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * {@code Window} logically divides up or groups the elements of a + * {@link PCollection} into finite windows according to a {@link WindowingFn}. + * The output of {@code Window} contains the same elements as input, but they + * have been logically assigned to windows. The next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}s, including one + * within composite transforms, will group by the combination of keys and + * windows. + + *

See {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} + * for more information about how grouping with windows works. + * + *

Windowing a {@code PCollection} allows chunks of it to be processed + * individually, before the entire {@code PCollection} is available. This is + * especially important for {@code PCollection}s with unbounded size, + * since the full {@code PCollection} is + * never available at once, since more data is continually arriving. + * For {@code PCollection}s with a bounded size (aka. conventional batch mode), + * by default, all data is implicitly in a single window, unless + * {@code Window} is applied. + * + *

For example, a simple form of windowing divides up the data into + * fixed-width time intervals, using {@link FixedWindows}. + * The following example demonstrates how to use {@code Window} in a pipeline + * that counts the number of occurrences of strings each minute: + * + *

 {@code
+ * PCollection items = ...;
+ * PCollection windowed_items = item.apply(
+ *   Window.into(FixedWindows.of(1, TimeUnit.MINUTES)));
+ * PCollection> windowed_counts = windowed_items.apply(
+ *   Count.create());
+ * } 
+ * + *

Let (data, timestamp) denote a data element along with its timestamp. + * Then, if the input to this pipeline consists of + * {("foo", 15s), ("bar", 30s), ("foo", 45s), ("foo", 1m30s)}, + * the output will be + * {(KV("foo", 2), 1m), (KV("bar", 1), 1m), (KV("foo", 1), 2m)} + * + * + *

Several predefined {@link WindowingFn}s are provided: + *

    + *
  • {@link FixedWindows} partitions the timestamps into fixed-width intervals. + *
  • {@link SlidingWindows} places data into overlapping fixed-width intervals. + *
  • {@link Sessions} groups data into sessions where each item in a window + * is separated from the next by no more than a specified gap. + *
+ * + * Additionally, custom {@link WindowingFn}s can be created, by creating new + * subclasses of {@link WindowingFn}. + */ +public class Window { + /** + * Creates a {@code Window} {@code PTransform} with the given name. + * + *

See the discussion of Naming in + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} for more explanation. + * + *

The resulting {@code PTransform} is incomplete, and its input/output + * type is not yet bound. Use {@link Window.Unbound#into} to specify the + * {@link WindowingFn} to use, which will also bind the input/output type of this + * {@code PTransform}. + */ + public static Unbound named(String name) { + return new Unbound().named(name); + } + + /** + * Creates a {@code Window} {@code PTransform} that uses the given + * {@link WindowingFn} to window the data. + * + *

The resulting {@code PTransform}'s types have been bound, with both the + * input and output being a {@code PCollection}, inferred from the types of + * the argument {@code WindowingFn}. It is ready to be applied, or further + * properties can be set on it first. + */ + public static Bound into(WindowingFn fn) { + return new Unbound().into(fn); + } + + /** + * An incomplete {@code Window} transform, with unbound input/output type. + * + *

Before being applied, {@link Window.Unbound#into} must be + * invoked to specify the {@link WindowingFn} to invoke, which will also + * bind the input/output type of this {@code PTransform}. + */ + public static class Unbound { + String name; + + Unbound() {} + + Unbound(String name) { + this.name = name; + } + + /** + * Returns a new {@code Window} transform that's like this + * transform but with the specified name. Does not modify this + * transform. The resulting transform is still incomplete. + * + *

See the discussion of Naming in + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} for more + * explanation. + */ + public Unbound named(String name) { + return new Unbound(name); + } + + /** + * Returns a new {@code Window} {@code PTransform} that's like this + * transform but which will use the given {@link WindowingFn}, and which has + * its input and output types bound. Does not modify this transform. The + * resulting {@code PTransform} is sufficiently specified to be applied, + * but more properties can still be specified. + */ + public Bound into(WindowingFn fn) { + return new Bound<>(name, fn); + } + } + + /** + * A {@code PTransform} that windows the elements of a {@code PCollection}, + * into finite windows according to a user-specified {@code WindowingFn}. + * + * @param The type of elements this {@code Window} is applied to + */ + public static class Bound extends PTransform, PCollection> { + WindowingFn fn; + + Bound(String name, WindowingFn fn) { + this.name = name; + this.fn = fn; + } + + /** + * Returns a new {@code Window} {@code PTransform} that's like this + * {@code PTransform} but with the specified name. Does not + * modify this {@code PTransform}. + * + *

See the discussion of Naming in + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} for more + * explanation. + */ + public Bound named(String name) { + return new Bound<>(name, fn); + } + + @Override + public PCollection apply(PCollection input) { + return PCollection.createPrimitiveOutputInternal(fn); + } + + @Override + protected Coder getDefaultOutputCoder() { + return getInput().getCoder(); + } + + @Override + protected String getKindString() { + return "Window.Into(" + StringUtils.approximateSimpleName(fn.getClass()) + ")"; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Creates a {@code Window} {@code PTransform} that does not change assigned + * windows, but will cause windows to be merged again as part of the next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}. + */ + public static Remerge remerge() { + return new Remerge(); + } + + /** + * {@code PTransform} that does not change assigned windows, but will cause + * windows to be merged again as part of the next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}. + */ + public static class Remerge extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + WindowingFn windowingFn = getInput().getWindowingFn(); + WindowingFn outputWindowingFn = + (windowingFn instanceof InvalidWindowingFn) + ? ((InvalidWindowingFn) windowingFn).getOriginalWindowingFn() + : windowingFn; + + return input.apply(ParDo.named("Identity").of(new DoFn() { + @Override public void processElement(ProcessContext c) { + c.output(c.element()); + } + })).setWindowingFnInternal(outputWindowingFn); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + PCollection input = transform.getInput(); + + DirectModeExecutionContext executionContext = new DirectModeExecutionContext(); + + TupleTag outputTag = new TupleTag<>(); + DoFn addWindowsDoFn = new AssignWindowsDoFn<>(transform.fn); + DoFnRunner addWindowsRunner = + DoFnRunner.createWithListOutputs( + context.getPipelineOptions(), + addWindowsDoFn, + PTuple.empty(), + outputTag, + new ArrayList>(), + executionContext.getStepContext(context.getStepName(transform)), + context.getAddCounterMutator()); + + addWindowsRunner.startBundle(); + + // Process input elements. + for (DirectPipelineRunner.ValueWithMetadata inputElem + : context.getPCollectionValuesWithMetadata(input)) { + executionContext.setKey(inputElem.getKey()); + addWindowsRunner.processElement(inputElem.getWindowedValue()); + } + + addWindowsRunner.finishBundle(); + + context.setPCollectionValuesWithMetadata( + transform.getOutput(), + executionContext.getOutput(outputTag)); + } + + + ///////////////////////////////////////////////////////////////////////////// + + static { + DataflowPipelineTranslator.registerTransformTranslator( + Bound.class, + new DataflowPipelineTranslator.TransformTranslator() { + @Override + public void translate( + Bound transform, + DataflowPipelineTranslator.TranslationContext context) { + translateHelper(transform, context); + } + }); + } + + private static void translateHelper( + Bound transform, + DataflowPipelineTranslator.TranslationContext context) { + context.addStep(transform, "Bucket"); + context.addInput(PropertyNames.PARALLEL_INPUT, transform.getInput()); + context.addOutput(PropertyNames.OUTPUT, transform.getOutput()); + + byte[] serializedBytes = serializeToByteArray(transform.fn); + String serializedJson = byteArrayToJsonString(serializedBytes); + assert Arrays.equals(serializedBytes, + jsonStringToByteArray(serializedJson)); + context.addInput(PropertyNames.SERIALIZED_FN, serializedJson); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingFn.java new file mode 100644 index 000000000000..0f049372555b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingFn.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.Collection; + +/** + * The argument to the {@link Window} transform used to assign elements into + * windows and to determine how windows are merged. See {@link Window} for more + * information on how {@code WindowingFn}s are used and for a library of + * predefined {@code WindowingFn}s. + * + *

Users will generally want to use the predefined + * {@code WindowingFn}s, but it is also possible to create new + * subclasses. + * TODO: Describe how to properly create {@code WindowingFn}s. + * + * @param type of elements being windowed + * @param {@link BoundedWindow} subclass used to represent the + * windows used by this {@code WindowingFn} + */ +public abstract class WindowingFn + implements Serializable { + + /** + * Information available when running {@link #assignWindows}. + */ + public abstract class AssignContext { + /** + * Returns the current element. + */ + public abstract T element(); + + /** + * Returns the timestamp of the current element. + */ + public abstract Instant timestamp(); + + /** + * Returns the windows the current element was in, prior to this + * {@code AssignFn} being called. + */ + public abstract Collection windows(); + } + + /** + * Given a timestamp and element, returns the set of windows into which it + * should be placed. + */ + public abstract Collection assignWindows(AssignContext c) throws Exception; + + /** + * Information available when running {@link #mergeWindows}. + */ + public abstract class MergeContext { + /** + * Returns the current set of windows. + */ + public abstract Collection windows(); + + /** + * Signals to the framework that the windows in {@code toBeMerged} should + * be merged together to form {@code mergeResult}. + * + *

{@code toBeMerged} should be a subset of {@link #windows} + * and disjoint from the {@code toBeMerged} set of previous calls + * to {@code merge}. + * + *

{@code mergeResult} must either not be in {@link #windows} or be in + * {@code toBeMerged}. + * + * @throws IllegalArgumentException if any elements of toBeMerged are not + * in windows(), or have already been merged + */ + public abstract void merge(Collection toBeMerged, W mergeResult) + throws Exception; + } + + /** + * Does whatever merging of windows is necessary. + * + *

See {@link MergeOverlappingIntervalWindows#mergeWindows} for an + * example of how to override this method. + */ + public abstract void mergeWindows(MergeContext c) throws Exception; + + /** + * Returns whether this performs the same merging as the given + * {@code WindowingFn}. + */ + public abstract boolean isCompatible(WindowingFn other); + + /** + * Returns the {@link Coder} used for serializing the windows used + * by this windowingFn. + */ + public abstract Coder windowCoder(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AbstractWindowSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AbstractWindowSet.java new file mode 100644 index 000000000000..dda2488dac34 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AbstractWindowSet.java @@ -0,0 +1,170 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.util.Arrays; +import java.util.Collection; + +/** + * Abstract class representing a set of active windows for a key. + */ +abstract class AbstractWindowSet { + /** + * Hook for determining how to keep track of active windows and when they + * should be marked as complete. + */ + interface ActiveWindowManager { + /** + * Notes that a window has been added to the active set. + * + *

The given window must not already be active. + */ + void addWindow(W window) throws Exception; + + /** + * Notes that a window has been explicitly removed from the active set. + * + *

The given window must currently be active. + * + *

Windows are implicitly removed from the active set when they are + * complete, and this method will not be called. This method is called when + * a window is merged into another and thus is no longer active. + */ + void removeWindow(W window) throws Exception; + } + + /** + * Wrapper around AbstractWindowSet that provides the MergeContext interface. + */ + static class WindowMergeContext + extends WindowingFn.MergeContext { + private final AbstractWindowSet windowSet; + + public WindowMergeContext( + AbstractWindowSet windowSet, + WindowingFn windowingFn) { + ((WindowingFn) windowingFn).super(); + this.windowSet = windowSet; + } + + @Override public Collection windows() { + return windowSet.windows(); + } + + @Override public void merge(Collection toBeMerged, W mergeResult) throws Exception { + windowSet.merge(toBeMerged, mergeResult); + } + } + + protected final K key; + protected final WindowingFn windowingFn; + protected final Coder inputCoder; + protected final DoFnProcessContext> context; + protected final ActiveWindowManager activeWindowManager; + + protected AbstractWindowSet( + K key, + WindowingFn windowingFn, + Coder inputCoder, + DoFnProcessContext> context, + ActiveWindowManager activeWindowManager) { + this.key = key; + this.windowingFn = windowingFn; + this.inputCoder = inputCoder; + this.context = context; + this.activeWindowManager = activeWindowManager; + } + + /** + * Returns the set of known windows. + */ + protected abstract Collection windows(); + + /** + * Returns the final value of the elements in the given window. + * + *

Illegal to call if the window does not exist in the set. + */ + protected abstract VO finalValue(W window) throws Exception; + + /** + * Adds the given value in the given window to the set. + * + *

If the window already exists, puts the element into that window. + * If not, adds the window to the set first, then puts the element + * in the window. + */ + protected abstract void put(W window, VI value) throws Exception; + + /** + * Removes the given window from the set. + * + *

Illegal to call if the window does not exist in the set. + * + *

{@code AbstractWindowSet} subclasses may throw + * {@link UnsupportedOperationException} if they do not support removing + * windows. + */ + protected abstract void remove(W window) throws Exception; + + /** + * Instructs this set to merge the windows in toBeMerged into mergeResult. + * + *

{@code toBeMerged} should be a subset of {@link #windows} + * and disjoint from the {@code toBeMerged} set of previous calls + * to {@code merge}. + * + *

{@code mergeResult} must either not be in {@link @windows} or be in + * {@code toBeMerged}. + * + *

{@code AbstractWindowSet} subclasses may throw + * {@link UnsupportedOperationException} if they do not support merging windows. + */ + protected abstract void merge(Collection toBeMerged, W mergeResult) throws Exception; + + /** + * Returns whether this window set contains the given window. + * + *

{@code AbstractWindowSet} subclasses may throw + * {@link UnsupportedOperationException} if they do not support querying for + * which windows are active. If this is the case, callers must ensure they + * do not call {@link #finalValue} on non-existent windows. + */ + protected abstract boolean contains(W window); + + /** + * Marks the window as complete, causing its elements to be emitted. + */ + public void markCompleted(W window) throws Exception { + VO value = finalValue(window); + remove(window); + context.outputWindowedValue( + KV.of(key, value), + window.maxTimestamp(), + Arrays.asList(window)); + } + + /** + * Hook for WindowSets to take action before they are deleted. + */ + protected void flush() throws Exception {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AggregatorImpl.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AggregatorImpl.java new file mode 100644 index 000000000000..e71bf7f8a7f0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AggregatorImpl.java @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MIN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * An implementation of the {@code Aggregator} interface. + * + * @param the type of input values + * @param the type of accumulator values + * @param the type of output value + */ +public class AggregatorImpl implements Aggregator { + + private final Counter counter; + + /* + * Constructs a new aggregator with the given name and aggregation logic + * specified in the CombineFn argument. The underlying counter is + * automatically added into the provided CounterSet. + * + *

If a counter with the same name already exists, it will be + * reused, as long as it has the same type. + */ + public AggregatorImpl(String name, + CombineFn combiner, + CounterSet.AddCounterMutator addCounterMutator) { + this((Counter) constructCounter(name, combiner), addCounterMutator); + } + + /* + * Constructs a new aggregator with the given name and aggregation logic + * specified in the SerializableFunction argument. The underlying counter is + * automatically added into the provided CounterSet. + * + *

If a counter with the same name already exists, it will be + * reused, as long as it has the same type. + */ + public AggregatorImpl(String name, + SerializableFunction, VO> combiner, + CounterSet.AddCounterMutator addCounterMutator) { + this((Counter) constructCounter(name, combiner), addCounterMutator); + } + + private AggregatorImpl(Counter counter, + CounterSet.AddCounterMutator addCounterMutator) { + try { + this.counter = addCounterMutator.addCounter(counter); + } catch (IllegalArgumentException ex) { + throw new IllegalArgumentException( + "aggregator's name collides with an existing aggregator " + + "or system-provided counter of an incompatible type"); + } + } + + private static Counter constructCounter(String name, Object combiner) { + if (combiner.getClass() == Sum.SumIntegerFn.class) { + return Counter.ints(name, SUM); + } else if (combiner.getClass() == Sum.SumLongFn.class) { + return Counter.longs(name, SUM); + } else if (combiner.getClass() == Sum.SumDoubleFn.class) { + return Counter.doubles(name, SUM); + } else if (combiner.getClass() == Min.MinIntegerFn.class) { + return Counter.ints(name, MIN); + } else if (combiner.getClass() == Min.MinLongFn.class) { + return Counter.longs(name, MIN); + } else if (combiner.getClass() == Min.MinDoubleFn.class) { + return Counter.doubles(name, MIN); + } else if (combiner.getClass() == Max.MaxIntegerFn.class) { + return Counter.ints(name, MAX); + } else if (combiner.getClass() == Max.MaxLongFn.class) { + return Counter.longs(name, MAX); + } else if (combiner.getClass() == Max.MaxDoubleFn.class) { + return Counter.doubles(name, MAX); + } else { + throw new IllegalArgumentException("unsupported combiner in Aggregator: " + + combiner.getClass().getName()); + } + } + + @Override + public void addValue(VI value) { + counter.addValue(value); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ApiErrorExtractor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ApiErrorExtractor.java new file mode 100644 index 000000000000..ad181cee40b3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ApiErrorExtractor.java @@ -0,0 +1,104 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.googleapis.json.GoogleJsonError; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.http.HttpStatusCodes; +import com.google.common.annotations.VisibleForTesting; + +import java.io.IOException; + +/** + * Translates exceptions from API calls into higher-level meaning, while allowing injectability + * for testing how API errors are handled. + */ +public class ApiErrorExtractor { + + public static final int STATUS_CODE_CONFLICT = 409; + public static final int STATUS_CODE_RANGE_NOT_SATISFIABLE = 416; + + /** + * Determines if the given exception indicates 'item not found'. + */ + public boolean itemNotFound(IOException e) { + if (e instanceof GoogleJsonResponseException) { + return (getHttpStatusCode((GoogleJsonResponseException) e)) == + HttpStatusCodes.STATUS_CODE_NOT_FOUND; + } + return false; + } + + /** + * Determines if the given GoogleJsonError indicates 'item not found'. + */ + public boolean itemNotFound(GoogleJsonError e) { + return e.getCode() == HttpStatusCodes.STATUS_CODE_NOT_FOUND; + } + + /** + * Checks if HTTP status code indicates the error specified. + */ + private boolean hasHttpCode(IOException e, int code) { + if (e instanceof GoogleJsonResponseException) { + return (getHttpStatusCode((GoogleJsonResponseException) e)) == code; + } + return false; + } + + /** + * Determines if the given exception indicates 'conflict' (already exists). + */ + public boolean alreadyExists(IOException e) { + return hasHttpCode(e, STATUS_CODE_CONFLICT); + } + + /** + * Determines if the given exception indicates 'range not satisfiable'. + */ + public boolean rangeNotSatisfiable(IOException e) { + return hasHttpCode(e, STATUS_CODE_RANGE_NOT_SATISFIABLE); + } + + /** + * Determines if the given exception indicates 'access denied'. + */ + public boolean accessDenied(GoogleJsonResponseException e) { + return getHttpStatusCode(e) == HttpStatusCodes.STATUS_CODE_FORBIDDEN; + } + + /** + * Determines if the given exception indicates 'access denied', recursively checking inner + * getCause() if outer exception isn't an instance of the correct class. + */ + public boolean accessDenied(IOException e) { + return (e.getCause() != null) && + (e.getCause() instanceof GoogleJsonResponseException) && + accessDenied((GoogleJsonResponseException) e.getCause()); + } + + /** + * Returns HTTP status code from the given exception. + * + * Note: GoogleJsonResponseException.getStatusCode() method is marked final therefore + * it cannot be mocked using Mockito. We use this helper so that we can override it in tests. + */ + @VisibleForTesting + int getHttpStatusCode(GoogleJsonResponseException e) { + return e.getStatusCode(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppEngineEnvironment.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppEngineEnvironment.java new file mode 100644 index 000000000000..f3b57a4508b0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppEngineEnvironment.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.lang.reflect.InvocationTargetException; + +/** Stores whether we are running within AppEngine or not. */ +public class AppEngineEnvironment { + /** + * True if running inside of AppEngine, false otherwise. + */ + @Deprecated + public static final boolean IS_APP_ENGINE = isAppEngine(); + + /** + * Attempts to detect whether we are inside of AppEngine. + *

+ * Purposely copied and left private from private code.google.common.util.concurrent.MoreExecutors#isAppEngine. + * + * @return true if we are inside of AppEngine, false otherwise. + */ + static boolean isAppEngine() { + if (System.getProperty("com.google.appengine.runtime.environment") == null) { + return false; + } + try { + // If the current environment is null, we're not inside AppEngine. + return Class.forName("com.google.apphosting.api.ApiProxy") + .getMethod("getCurrentEnvironment") + .invoke(null) != null; + } catch (ClassNotFoundException e) { + // If ApiProxy doesn't exist, we're not on AppEngine at all. + return false; + } catch (InvocationTargetException e) { + // If ApiProxy throws an exception, we're not in a proper AppEngine environment. + return false; + } catch (IllegalAccessException e) { + // If the method isn't accessible, we're not on a supported version of AppEngine; + return false; + } catch (NoSuchMethodException e) { + // If the method doesn't exist, we're not on a supported version of AppEngine; + return false; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AssignWindowsDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AssignWindowsDoFn.java new file mode 100644 index 000000000000..7649a8c63724 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AssignWindowsDoFn.java @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * {@link DoFn} that tags elements of a PCollection with windows, according + * to the provided {@link WindowingFn}. + * @param Type of elements being windowed + * @param Window type + */ +public class AssignWindowsDoFn extends DoFn { + private WindowingFn fn; + + public AssignWindowsDoFn(WindowingFn fn) { + this.fn = fn; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + final DoFnProcessContext context = (DoFnProcessContext) c; + Collection windows = + ((WindowingFn) fn).assignWindows( + ((WindowingFn) fn).new AssignContext() { + @Override + public T element() { + return context.element(); + } + + @Override + public Instant timestamp() { + return context.timestamp(); + } + + @Override + public Collection windows() { + return context.windows(); + } + }); + + context.outputWindowedValue(context.element(), context.timestamp(), windows); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOff.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOff.java new file mode 100644 index 000000000000..78e8e0538b82 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOff.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOff; +import com.google.common.base.Preconditions; + +/** + * Implementation of {@link BackOff} that increases the back off period for each retry attempt + * using a randomization function that grows exponentially. + *

+ * Example: The initial interval is .5 seconds and the maximum number of retries is 10. + * For 10 tries the sequence will be (values in seconds): + *

+ * + *
+   retry#      retry_interval     randomized_interval
+   1             0.5                [0.25,   0.75]
+   2             0.75               [0.375,  1.125]
+   3             1.125              [0.562,  1.687]
+   4             1.687              [0.8435, 2.53]
+   5             2.53               [1.265,  3.795]
+   6             3.795              [1.897,  5.692]
+   7             5.692              [2.846,  8.538]
+   8             8.538              [4.269, 12.807]
+   9            12.807              [6.403, 19.210]
+   10           {@link BackOff#STOP}
+ * 
+ * + *

+ * Implementation is not thread-safe. + *

+ */ +public class AttemptBoundedExponentialBackOff implements BackOff { + public static final double DEFAULT_MULTIPLIER = 1.5; + public static final double DEFAULT_RANDOMIZATION_FACTOR = 0.5; + private final int maximumNumberOfAttempts; + private final long initialIntervalMillis; + private int currentAttempt; + + public AttemptBoundedExponentialBackOff(int maximumNumberOfAttempts, long initialIntervalMillis) { + Preconditions.checkArgument(maximumNumberOfAttempts > 0, + "Maximum number of attempts must be greater than zero."); + Preconditions.checkArgument(initialIntervalMillis > 0, + "Initial interval must be greater than zero."); + this.maximumNumberOfAttempts = maximumNumberOfAttempts; + this.initialIntervalMillis = initialIntervalMillis; + reset(); + } + + @Override + public void reset() { + currentAttempt = 1; + } + + @Override + public long nextBackOffMillis() { + if (currentAttempt >= maximumNumberOfAttempts) { + return BackOff.STOP; + } + double currentIntervalMillis = initialIntervalMillis + * Math.pow(DEFAULT_MULTIPLIER, currentAttempt - 1); + double randomOffset = (Math.random() * 2 - 1) + * DEFAULT_RANDOMIZATION_FACTOR * currentIntervalMillis; + currentAttempt += 1; + return Math.round(currentIntervalMillis + randomOffset); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Base64Utils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Base64Utils.java new file mode 100644 index 000000000000..0ea25102e132 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Base64Utils.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * Utilities related to Base64 encoding. + */ +public class Base64Utils { + /** + * Returns an upper bound of the length of non-chunked Base64 encoded version + * of the string of the given length. + */ + public static int getBase64Length(int length) { + return 4 * ((length + 2) / 3); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BatchModeExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BatchModeExecutionContext.java new file mode 100644 index 000000000000..2d42407c9437 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BatchModeExecutionContext.java @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.values.CodedTupleTag; +import com.google.cloud.dataflow.sdk.values.CodedTupleTagMap; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * {@link ExecutionContext} for use in batch mode. + */ +public class BatchModeExecutionContext extends ExecutionContext { + private Object key; + private Map> timers = new HashMap<>(); + + /** + * Create a new {@link ExecutionContext.StepContext}. + */ + @Override + public ExecutionContext.StepContext createStepContext(String stepName) { + return new StepContext(stepName); + } + + /** + * Sets the key of the work currently being processed. + */ + public void setKey(Object key) { + this.key = key; + } + + /** + * Returns the key of the work currently being processed. + * + *

If there is not a currently defined key, returns null. + */ + public Object getKey() { + return key; + } + + @Override + public void setTimer(String timer, Instant timestamp) { + Map keyTimers = timers.get(getKey()); + if (keyTimers == null) { + keyTimers = new HashMap<>(); + timers.put(getKey(), keyTimers); + } + keyTimers.put(timer, timestamp); + } + + @Override + public void deleteTimer(String timer) { + Map keyTimers = timers.get(getKey()); + if (keyTimers != null) { + keyTimers.remove(timer); + } + } + + public List> getAllTimers() { + List> result = new ArrayList<>(); + for (Map.Entry> keyTimers : timers.entrySet()) { + for (Map.Entry timer : keyTimers.getValue().entrySet()) { + result.add(TimerOrElement.timer(timer.getKey(), timer.getValue(), keyTimers.getKey())); + } + } + return result; + } + + /** + * {@link ExecutionContext.StepContext} used in batch mode. + */ + class StepContext extends ExecutionContext.StepContext { + private Map, Object>> state = new HashMap<>(); + private Map, List>> tagLists = new HashMap<>(); + + StepContext(String stepName) { + super(stepName); + } + + @Override + public void store(CodedTupleTag tag, T value) { + Map, Object> perKeyState = state.get(getKey()); + if (perKeyState == null) { + perKeyState = new HashMap<>(); + state.put(getKey(), perKeyState); + } + perKeyState.put(tag, value); + } + + @Override + public CodedTupleTagMap lookup(List> tags) { + Map, Object> perKeyState = state.get(getKey()); + Map, Object> map = new HashMap<>(); + if (perKeyState != null) { + for (CodedTupleTag tag : tags) { + map.put(tag, perKeyState.get(tag)); + } + } + return CodedTupleTagMap.of(map); + } + + @Override + public void writeToTagList(CodedTupleTag tag, T value, Instant timestamp) { + Map, List> perKeyTagLists = tagLists.get(getKey()); + if (perKeyTagLists == null) { + perKeyTagLists = new HashMap<>(); + tagLists.put(getKey(), perKeyTagLists); + } + List tagList = perKeyTagLists.get(tag); + if (tagList == null) { + tagList = new ArrayList<>(); + perKeyTagLists.put(tag, tagList); + } + tagList.add(value); + } + + @Override + public void deleteTagList(CodedTupleTag tag) { + Map, List> perKeyTagLists = tagLists.get(getKey()); + if (perKeyTagLists != null) { + perKeyTagLists.remove(tag); + } + } + + @Override + public Iterable readTagList(CodedTupleTag tag) { + Map, List> perKeyTagLists = tagLists.get(getKey()); + if (perKeyTagLists == null || perKeyTagLists.get(tag) == null) { + return new ArrayList(); + } + List result = new ArrayList(); + for (Object element : perKeyTagLists.get(tag)) { + result.add((T) element); + } + return result; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserter.java new file mode 100644 index 000000000000..c241ee2f2591 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserter.java @@ -0,0 +1,240 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableDataInsertAllRequest; +import com.google.api.services.bigquery.model.TableDataInsertAllResponse; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Inserts rows into BigQuery. + */ +public class BigQueryTableInserter { + private static final Logger LOG = LoggerFactory.getLogger(BigQueryTableInserter.class); + + // Approximate amount of table data to upload per InsertAll request. + private static final long UPLOAD_BATCH_SIZE = 64 * 1024; + + private final Bigquery client; + private final TableReference ref; + + /** + * Constructs a new row inserter. + * + * @param client a BigQuery client + * @param ref identifies the table to insert into + */ + public BigQueryTableInserter(Bigquery client, TableReference ref) { + this.client = client; + this.ref = ref; + } + + /** + * Insert all rows from the given iterator. + */ + public void insertAll(Iterator rowIterator) throws IOException { + insertAll(rowIterator, null); + } + + /** + * Insert all rows from the given iterator using specified insertIds if not null. + */ + public void insertAll(Iterator rowIterator, + @Nullable Iterator insertIdIterator) throws IOException { + // Upload in batches. + List rows = new LinkedList<>(); + int numInserted = 0; + int dataSize = 0; + while (rowIterator.hasNext()) { + TableRow row = rowIterator.next(); + TableDataInsertAllRequest.Rows out = new TableDataInsertAllRequest.Rows(); + if (insertIdIterator != null) { + if (insertIdIterator.hasNext()) { + out.setInsertId(insertIdIterator.next()); + } else { + throw new AssertionError("If insertIdIterator is not null it needs to have at least " + + "as many elements as rowIterator"); + } + } + out.setJson(row.getUnknownKeys()); + rows.add(out); + + dataSize += row.toString().length(); + if (dataSize >= UPLOAD_BATCH_SIZE || !rowIterator.hasNext()) { + TableDataInsertAllRequest content = new TableDataInsertAllRequest(); + content.setRows(rows); + + LOG.info("Number of rows in BigQuery insert: {}", rows.size()); + numInserted += rows.size(); + + Bigquery.Tabledata.InsertAll insert = client.tabledata() + .insertAll(ref.getProjectId(), ref.getDatasetId(), ref.getTableId(), + content); + TableDataInsertAllResponse response = insert.execute(); + List errors = response + .getInsertErrors(); + if (errors != null && !errors.isEmpty()) { + throw new IOException("Insert failed: " + errors); + } + + dataSize = 0; + rows.clear(); + } + } + + LOG.info("Number of rows written to BigQuery: {}", numInserted); + } + + /** + * Retrieves or creates the table. + *

+ * The table is checked to conform to insertion requirements as specified + * by WriteDisposition and CreateDisposition. + *

+ * If table truncation is requested (WriteDisposition.WRITE_TRUNCATE), then + * this will re-create the table if necessary to ensure it is empty. + *

+ * If an empty table is required (WriteDisposition.WRITE_EMPTY), then this + * will fail if the table exists and is not empty. + *

+ * When constructing a table, a {@code TableSchema} must be available. If a + * schema is provided, then it will be used. If no schema is provided, but + * an existing table is being cleared (WRITE_TRUNCATE option above), then + * the existing schema will be re-used. If no schema is available, then an + * {@code IOException} is thrown. + */ + public Table getOrCreateTable( + WriteDisposition writeDisposition, + CreateDisposition createDisposition, + @Nullable TableSchema schema) throws IOException { + // Check if table already exists. + Bigquery.Tables.Get get = client.tables() + .get(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + Table table = null; + try { + table = get.execute(); + } catch (IOException e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if (!errorExtractor.itemNotFound(e) || + createDisposition != CreateDisposition.CREATE_IF_NEEDED) { + // Rethrow. + throw e; + } + } + + // If we want an empty table, and it isn't, then delete it first. + if (table != null) { + if (writeDisposition == WriteDisposition.WRITE_APPEND) { + return table; + } + + boolean empty = isEmpty(); + if (empty) { + if (writeDisposition == WriteDisposition.WRITE_TRUNCATE) { + LOG.info("Empty table found, not removing {}", BigQueryIO.toTableSpec(ref)); + } + return table; + + } else if (writeDisposition == WriteDisposition.WRITE_EMPTY) { + throw new IOException("WriteDisposition is WRITE_EMPTY, " + + "but table is not empty"); + } + + // Reuse the existing schema if none was provided. + if (schema == null) { + schema = table.getSchema(); + } + + // Delete table and fall through to re-creating it below. + LOG.info("Deleting table {}", BigQueryIO.toTableSpec(ref)); + Bigquery.Tables.Delete delete = client.tables() + .delete(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + delete.execute(); + } + + if (schema == null) { + throw new IllegalArgumentException( + "Table schema required for new table."); + } + + // Create the table. + return tryCreateTable(schema); + } + + /** + * Checks if a table is empty. + */ + public boolean isEmpty() throws IOException { + Bigquery.Tabledata.List list = client.tabledata() + .list(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + list.setMaxResults(1L); + TableDataList dataList = list.execute(); + + return dataList.getRows() == null || dataList.getRows().isEmpty(); + } + + /** + * Tries to create the BigQuery table. + * If a table with the same name already exists in the dataset, the table + * creation fails, and the function returns null. In such a case, + * the existing table doesn't necessarily have the same schema as specified + * by the parameter. + * + * @param schema Schema of the new BigQuery table. + * @return The newly created BigQuery table information, or null if the table + * with the same name already exists. + * @throws IOException if other error than already existing table occurs. + */ + @Nullable + public Table tryCreateTable(TableSchema schema) throws IOException { + LOG.info("Trying to create BigQuery table: {}", BigQueryIO.toTableSpec(ref)); + + Table content = new Table(); + content.setTableReference(ref); + content.setSchema(schema); + + try { + return client.tables() + .insert(ref.getProjectId(), ref.getDatasetId(), content) + .execute(); + } catch (IOException e) { + if (new ApiErrorExtractor().alreadyExists(e)) { + LOG.info("The BigQuery table already exists."); + return null; + } + throw e; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIterator.java new file mode 100644 index 000000000000..a6ea658ae3f4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIterator.java @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Data; +import com.google.api.client.util.Preconditions; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableCell; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; + +/** + * Iterates over all rows in a table. + */ +public class BigQueryTableRowIterator implements Iterator, Closeable { + + private final Bigquery client; + private final TableReference ref; + private TableSchema schema; + private String pageToken; + private Iterator rowIterator; + // Set true when the final page is seen from the service. + private boolean lastPage = false; + + public BigQueryTableRowIterator(Bigquery client, TableReference ref) { + this.client = client; + this.ref = ref; + } + + @Override + public boolean hasNext() { + try { + if (!isOpen()) { + open(); + } + + if (!rowIterator.hasNext() && !lastPage) { + readNext(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return rowIterator.hasNext(); + } + + /** + * Adjusts a field returned from the API to + * match the type that will be seen when run on the + * backend service. The end result is: + * + *

    + *
  • Nulls are {@code null}. + *
  • Repeated fields are lists. + *
  • Record columns are {@link TableRow}s. + *
  • {@code BOOLEAN} columns are JSON booleans, hence Java {@link Boolean}s. + *
  • {@code FLOAT} columns are JSON floats, hence Java {@link Double}s. + *
  • Every other atomic type is a {@link String}. + *

+ * + *

Note that currently integers are encoded as strings to match + * the behavior of the backend service. + */ + private Object getTypedCellValue(TableFieldSchema fieldSchema, Object v) { + // In the input from the BQ API, atomic types all come in as + // strings, while on the Dataflow service they have more precise + // types. + + if (Data.isNull(v)) { + return null; + } + + if (Objects.equals(fieldSchema.getMode(), "REPEATED")) { + TableFieldSchema elementSchema = fieldSchema.clone().setMode("REQUIRED"); + List rawValues = (List) v; + List values = new ArrayList(rawValues.size()); + for (Object element : rawValues) { + values.add(getTypedCellValue(elementSchema, element)); + } + return values; + } + + if (fieldSchema.getType().equals("RECORD")) { + return getTypedTableRow(fieldSchema.getFields(), (TableRow) v); + } + + if (fieldSchema.getType().equals("FLOAT")) { + return Double.parseDouble((String) v); + } + + if (fieldSchema.getType().equals("BOOLEAN")) { + return Boolean.parseBoolean((String) v); + } + + return v; + } + + private TableRow getTypedTableRow(List fields, TableRow rawRow) { + List cells = rawRow.getF(); + Preconditions.checkState(cells.size() == fields.size()); + + Iterator cellIt = cells.iterator(); + Iterator fieldIt = fields.iterator(); + + TableRow row = new TableRow(); + while (cellIt.hasNext()) { + TableCell cell = cellIt.next(); + TableFieldSchema fieldSchema = fieldIt.next(); + row.set(fieldSchema.getName(), getTypedCellValue(fieldSchema, cell.getV())); + } + return row; + } + + @Override + public TableRow next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + // Embed schema information into the raw row, so that values have an + // associated key. This matches how rows are read when using the + // DataflowPipelineRunner. + return getTypedTableRow(schema.getFields(), rowIterator.next()); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + private void readNext() throws IOException { + Bigquery.Tabledata.List list = client.tabledata() + .list(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + if (pageToken != null) { + list.setPageToken(pageToken); + } + + TableDataList result = list.execute(); + pageToken = result.getPageToken(); + rowIterator = result.getRows() != null ? result.getRows().iterator() : + Collections.emptyIterator(); + + // The server may return a page token indefinitely on a zero-length table. + if (pageToken == null || + result.getTotalRows() != null && result.getTotalRows() == 0) { + lastPage = true; + } + } + + @Override + public void close() throws IOException { + // Prevent any further requests. + lastPage = true; + } + + private boolean isOpen() { + return schema != null; + } + + /** + * Opens the table for read. + * @throws IOException on failure + */ + private void open() throws IOException { + // Get table schema. + Bigquery.Tables.Get get = client.tables() + .get(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + Table table = get.execute(); + schema = table.getSchema(); + + // Read the first page of results. + readNext(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BufferingWindowSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BufferingWindowSet.java new file mode 100644 index 000000000000..4801d6d64c3c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BufferingWindowSet.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.WindowUtils.bufferTag; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.SetCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.values.CodedTupleTag; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A WindowSet allowing windows to be merged and deleted. + */ +class BufferingWindowSet + extends AbstractWindowSet, W> { + /** + * Tag for storing the merge tree, the data structure that keeps + * track of which windows have been merged together. + */ + private final CodedTupleTag>> mergeTreeTag = + CodedTupleTag.of( + "mergeTree", + MapCoder.of( + windowingFn.windowCoder(), + SetCoder.of(windowingFn.windowCoder()))); + + /** + * A map of live windows to windows that were merged into them. + * + *

The keys of the map correspond to the set of (merged) windows and the values + * are the no-longer-present windows that were merged into the keys. A given + * window can appear in both the key and value of a single entry, but other at + * most once across all keys and values. + */ + private final Map> mergeTree; + + /** + * Used to determine if writing the mergeTree (which is relatively stable) + * is necessary. + */ + private final Map> originalMergeTree; + + protected BufferingWindowSet( + K key, + WindowingFn windowingFn, + Coder inputCoder, + DoFnProcessContext>> context, + ActiveWindowManager activeWindowManager) throws Exception { + super(key, windowingFn, inputCoder, context, activeWindowManager); + + mergeTree = emptyIfNull( + context.context.stepContext.lookup(Arrays.asList(mergeTreeTag)) + .get(mergeTreeTag)); + + originalMergeTree = deepCopy(mergeTree); + } + + @Override + public void put(W window, V value) throws Exception { + context.context.stepContext.writeToTagList( + bufferTag(window, windowingFn.windowCoder(), inputCoder), + value, + context.timestamp()); + if (!mergeTree.containsKey(window)) { + mergeTree.put(window, new HashSet()); + activeWindowManager.addWindow(window); + } + } + + @Override + public void remove(W window) throws Exception { + mergeTree.remove(window); + activeWindowManager.removeWindow(window); + } + + @Override + public void merge(Collection otherWindows, W newWindow) throws Exception { + Set subWindows = mergeTree.get(newWindow); + if (subWindows == null) { + subWindows = new HashSet<>(); + } + for (W other : otherWindows) { + if (!mergeTree.containsKey(other)) { + throw new IllegalArgumentException("Tried to merge a non-existent window: " + other); + } + subWindows.addAll(mergeTree.get(other)); + subWindows.add(other); + remove(other); + } + mergeTree.put(newWindow, subWindows); + activeWindowManager.addWindow(newWindow); + } + + @Override + public Collection windows() { + return Collections.unmodifiableSet(mergeTree.keySet()); + } + + @Override + public boolean contains(W window) { + return mergeTree.containsKey(window); + } + + @Override + protected Iterable finalValue(W window) throws Exception { + if (!contains(window)) { + throw new IllegalStateException("finalValue called for non-existent window"); + } + + List toEmit = new ArrayList<>(); + // This is the set of windows that we're currently emitting. + Set curWindows = new HashSet<>(); + curWindows.add(window); + curWindows.addAll(mergeTree.get(window)); + + // This is the set of unflushed windows (for preservation detection). + Set otherWindows = new HashSet<>(); + for (Map.Entry> entry : mergeTree.entrySet()) { + if (!entry.getKey().equals(window)) { + otherWindows.add(entry.getKey()); + otherWindows.addAll(entry.getValue()); + } + } + + for (W curWindow : curWindows) { + Iterable items = context.context.stepContext.readTagList(bufferTag( + curWindow, windowingFn.windowCoder(), inputCoder)); + for (V item : items) { + toEmit.add(item); + } + context.context.stepContext.deleteTagList(bufferTag( + curWindow, windowingFn.windowCoder(), inputCoder)); + } + + return toEmit; + } + + @Override + public void flush() throws Exception { + if (!mergeTree.equals(originalMergeTree)) { + context.context.stepContext.store(mergeTreeTag, mergeTree); + } + } + + private static Map> emptyIfNull(Map> input) { + if (input == null) { + return new HashMap<>(); + } else { + for (Map.Entry> entry : input.entrySet()) { + if (entry.getValue() == null) { + entry.setValue(new HashSet()); + } + } + return input; + } + } + + private Map> deepCopy(Map> mergeTree) { + Map> newMergeTree = new HashMap<>(); + for (Map.Entry> entry : mergeTree.entrySet()) { + newMergeTree.put(entry.getKey(), new HashSet(entry.getValue())); + } + return newMergeTree; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudCounterUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudCounterUtils.java new file mode 100644 index 000000000000..f96ba486f24d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudCounterUtils.java @@ -0,0 +1,104 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.services.dataflow.model.MetricStructuredName; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Utilities for working with CloudCounters. + */ +public class CloudCounterUtils { + private static final Logger LOG = LoggerFactory.getLogger(CloudCounterUtils.class); + + public static List extractCounters( + CounterSet counters, boolean delta) { + synchronized (counters) { + List cloudCounters = new ArrayList<>(counters.size()); + for (Counter counter : counters) { + try { + MetricUpdate cloudCounter = extractCounter(counter, delta); + if (cloudCounter != null) { + cloudCounters.add(cloudCounter); + } + } catch (IllegalArgumentException exn) { + LOG.warn("Error extracting counter value: ", exn); + } + } + return cloudCounters; + } + } + + public static MetricUpdate extractCounter(Counter counter, boolean delta) { + // TODO: Omit no-op counter updates, for counters whose + // values haven't changed since the last time we sent them. + synchronized (counter) { + MetricStructuredName name = new MetricStructuredName(); + name.setName(counter.getName()); + MetricUpdate metricUpdate = new MetricUpdate() + .setName(name) + .setKind(counter.getKind().name()) + .setCumulative(!delta); + switch (counter.getKind()) { + case SUM: + case MAX: + case MIN: + case AND: + case OR: + metricUpdate.setScalar(CloudObject.forKnownType(counter.getAggregate(delta))); + break; + case MEAN: { + long countUpdate = counter.getCount(delta); + if (countUpdate <= 0) { + return null; + } + metricUpdate.setMeanSum(CloudObject.forKnownType(counter.getAggregate(delta))); + metricUpdate.setMeanCount(CloudObject.forKnownType(countUpdate)); + break; + } + case SET: { + Set values = counter.getSet(delta); + if (values.isEmpty()) { + return null; + } + Set encodedSet = new HashSet(values.size()); + for (Object value : values) { + encodedSet.add(CloudObject.forKnownType(value)); + } + metricUpdate.setSet(encodedSet); + break; + } + default: + throw new IllegalArgumentException("unexpected kind of counter"); + } + if (delta) { + counter.resetDelta(); + } + return metricUpdate; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudKnownType.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudKnownType.java new file mode 100644 index 000000000000..ad57b9953631 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudKnownType.java @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.Nullable; + +/** A utility for manipulating well-known cloud types. */ +enum CloudKnownType { + TEXT("http://schema.org/Text", String.class) { + @Override + public T parse(Object value, Class clazz) { + return clazz.cast(value); + } + }, + BOOLEAN("http://schema.org/Boolean", Boolean.class) { + @Override + public T parse(Object value, Class clazz) { + return clazz.cast(value); + } + }, + INTEGER("http://schema.org/Integer", Long.class, Integer.class) { + @Override + public T parse(Object value, Class clazz) { + Object result = null; + if (value.getClass() == clazz) { + result = value; + } else if (clazz == Long.class) { + if (value instanceof Integer) { + result = ((Integer) value).longValue(); + } else if (value instanceof String) { + result = Long.valueOf((String) value); + } + } else if (clazz == Integer.class) { + if (value instanceof Long) { + result = ((Long) value).intValue(); + } else if (value instanceof String) { + result = Integer.valueOf((String) value); + } + } + return clazz.cast(result); + } + }, + FLOAT("http://schema.org/Float", Double.class, Float.class) { + @Override + public T parse(Object value, Class clazz) { + Object result = null; + if (value.getClass() == clazz) { + result = value; + } else if (clazz == Double.class) { + if (value instanceof Float) { + result = ((Float) value).doubleValue(); + } else if (value instanceof String) { + result = Double.valueOf((String) value); + } + } else if (clazz == Float.class) { + if (value instanceof Double) { + result = ((Double) value).floatValue(); + } else if (value instanceof String) { + result = Float.valueOf((String) value); + } + } + return clazz.cast(result); + } + }; + + private final String uri; + private final Class[] classes; + + private CloudKnownType(String uri, Class... classes) { + this.uri = uri; + this.classes = classes; + } + + public String getUri() { + return uri; + } + + public abstract T parse(Object value, Class clazz); + + public Class defaultClass() { + return classes[0]; + } + + private static final Map typesByUri = + Collections.unmodifiableMap(buildTypesByUri()); + + private static Map buildTypesByUri() { + Map result = new HashMap<>(); + for (CloudKnownType ty : CloudKnownType.values()) { + result.put(ty.getUri(), ty); + } + return result; + } + + @Nullable + public static CloudKnownType forUri(@Nullable String uri) { + if (uri == null) { + return null; + } + return typesByUri.get(uri); + } + + private static final Map, CloudKnownType> typesByClass = + Collections.unmodifiableMap(buildTypesByClass()); + + private static Map, CloudKnownType> buildTypesByClass() { + Map, CloudKnownType> result = new HashMap<>(); + for (CloudKnownType ty : CloudKnownType.values()) { + for (Class clazz : ty.classes) { + result.put(clazz, ty); + } + } + return result; + } + + @Nullable + public static CloudKnownType forClass(Class clazz) { + return typesByClass.get(clazz); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudMetricUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudMetricUtils.java new file mode 100644 index 000000000000..da99e5b3c385 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudMetricUtils.java @@ -0,0 +1,73 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.services.dataflow.model.MetricStructuredName; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.util.common.Metric; +import com.google.cloud.dataflow.sdk.util.common.Metric.DoubleMetric; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Utilities for working with Dataflow API Metrics. + */ +public class CloudMetricUtils { + // Do not instantiate. + private CloudMetricUtils() {} + + /** + * Returns a List of {@link MetricUpdate}s representing the given Metrics. + */ + public static List extractCloudMetrics( + Collection> metrics, + String workerId) { + List cloudMetrics = new ArrayList<>(metrics.size()); + for (Metric metric : metrics) { + cloudMetrics.add(extractCloudMetric(metric, workerId)); + } + return cloudMetrics; + } + + /** + * Returns a {@link MetricUpdate} representing the given Metric. + */ + public static MetricUpdate extractCloudMetric(Metric metric, String workerId) { + if (metric instanceof DoubleMetric) { + return extractCloudMetric( + metric, + ((DoubleMetric) metric).getValue(), + workerId); + } else { + throw new IllegalArgumentException("unexpected kind of Metric"); + } + } + + private static MetricUpdate extractCloudMetric( + Metric metric, Double value, String workerId) { + MetricStructuredName name = new MetricStructuredName(); + name.setName(metric.getName()); + Map context = new HashMap<>(); + context.put("workerId", workerId); + name.setContext(context); + return new MetricUpdate().setName(name).setScalar(CloudObject.forFloat(value)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudObject.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudObject.java new file mode 100644 index 000000000000..973fe5ab7707 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudObject.java @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.api.client.util.Preconditions.checkNotNull; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; + +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A representation of an arbitrary Java object to be instantiated by Dataflow + * workers. + *

+ * Typically, an object to be written by the SDK to the Dataflow service will + * implement a method (typically called {@code asCloudObject()}) which returns a + * {@code CloudObject} to represent the object in the protocol. Once the + * {@code CloudObject} is constructed, the method should explicitly add + * additional properties to be presented during deserialization, representing + * child objects by building additional {@code CloudObject}s. + */ +public final class CloudObject extends GenericJson { + /** + * Constructs a {@code CloudObject} by copying the supplied serialized object spec, + * which must represent an SDK object serialized for transport via the + * Dataflow API. + *

+ * The most common use of this method is during deserialization on the worker, + * where it's used as a binding type during instance construction. + * + * @param spec supplies the serialized form of the object as a nested map + * @throws RuntimeException if the supplied map does not represent an SDK object + */ + public static CloudObject fromSpec(Map spec) { + CloudObject result = new CloudObject(); + result.putAll(spec); + if (result.className == null) { + throw new RuntimeException("Unable to create an SDK object from " + spec + + ": Object class not specified (missing \"" + + PropertyNames.OBJECT_TYPE_NAME + "\" field)"); + } + return result; + } + + /** + * Constructs a {@code CloudObject} to be used for serializing an instance of + * the supplied class for transport via the Dataflow API. The instance + * parameters to be serialized must be supplied explicitly after the + * {@code CloudObject} is created, by using {@link CloudObject#put}. + * + * @param cls the class to use when deserializing the object on the worker + */ + public static CloudObject forClass(Class cls) { + CloudObject result = new CloudObject(); + result.className = checkNotNull(cls).getName(); + return result; + } + + /** + * Constructs a {@code CloudObject} to be used for serializing data to be + * deserialized using the supplied class name the supplied class name for + * transport via the Dataflow API. The instance parameters to be serialized + * must be supplied explicitly after the {@code CloudObject} is created, by + * using {@link CloudObject#put}. + * + * @param className the class to use when deserializing the object on the worker + */ + public static CloudObject forClassName(String className) { + CloudObject result = new CloudObject(); + result.className = checkNotNull(className); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forString(String value) { + CloudObject result = forClassName(CloudKnownType.TEXT.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forBoolean(Boolean value) { + CloudObject result = forClassName(CloudKnownType.BOOLEAN.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forInteger(Long value) { + CloudObject result = forClassName(CloudKnownType.INTEGER.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forInteger(Integer value) { + CloudObject result = forClassName(CloudKnownType.INTEGER.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forFloat(Float value) { + CloudObject result = forClassName(CloudKnownType.FLOAT.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forFloat(Double value) { + CloudObject result = forClassName(CloudKnownType.FLOAT.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value of a + * well-known cloud object type. + * @param value the scalar value to represent. + * @throw RuntimeException if the value does not have a {@link CloudKnownType} + * mapping + */ + public static CloudObject forKnownType(Object value) { + @Nullable CloudKnownType ty = CloudKnownType.forClass(value.getClass()); + if (ty == null) { + throw new RuntimeException("Unable to represent value via the Dataflow API: " + value); + } + CloudObject result = forClassName(ty.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + @Key(PropertyNames.OBJECT_TYPE_NAME) + private String className; + + private CloudObject() {} + + /** + * Gets the name of the Java class which this CloudObject represents. + */ + public String getClassName() { + return className; + } + + @Override + public CloudObject clone() { + return (CloudObject) super.clone(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudSourceUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudSourceUtils.java new file mode 100644 index 000000000000..7d97948af437 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudSourceUtils.java @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.runners.worker.SourceFactory; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Utilities for working with Source Dataflow API definitions and {@link Source} + * objects. + */ +public class CloudSourceUtils { + /** + * Returns a copy of the source with {@code baseSpecs} flattened into {@code spec}. + * On conflict for a parameter name, values in {@code spec} override values in {@code baseSpecs}, + * and later values in {@code baseSpecs} override earlier ones. + */ + public static com.google.api.services.dataflow.model.Source + flattenBaseSpecs(com.google.api.services.dataflow.model.Source source) { + if (source.getBaseSpecs() == null) { + return source; + } + Map params = new HashMap<>(); + for (Map baseSpec : source.getBaseSpecs()) { + params.putAll(baseSpec); + } + params.putAll(source.getSpec()); + + com.google.api.services.dataflow.model.Source result = source.clone(); + result.setSpec(params); + result.setBaseSpecs(null); + return result; + } + + /** Reads all elements from the given {@link Source}. */ + public static List readElemsFromSource(Source source) { + List elems = new ArrayList<>(); + try (Source.SourceIterator it = source.iterator()) { + while (it.hasNext()) { + elems.add(it.next()); + } + } catch (IOException e) { + throw new RuntimeException("Failed to read from source: " + source, e); + } + return elems; + } + + /** + * Creates a {@link Source} from the given Dataflow Source API definition and + * reads all elements from it. + */ + public static List readElemsFromSource( + com.google.api.services.dataflow.model.Source source) { + try { + return readElemsFromSource(SourceFactory.create(null, source, null)); + } catch (Exception e) { + throw new RuntimeException("Failed to read from source: " + source.toString(), e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CoderUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CoderUtils.java new file mode 100644 index 000000000000..c77f35a45da2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CoderUtils.java @@ -0,0 +1,202 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addList; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoderBase; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoderBase; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonTypeInfo.As; +import com.fasterxml.jackson.annotation.JsonTypeInfo.Id; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver; +import com.fasterxml.jackson.databind.jsontype.impl.TypeIdResolverBase; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.type.TypeFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.TypeVariable; + +/** + * Utilities for working with Coders. + */ +public final class CoderUtils { + private CoderUtils() {} // Non-instantiable + + /** + * Coder class-name alias for a key-value type. + */ + public static final String KIND_PAIR = "kind:pair"; + + /** + * Coder class-name alias for a stream type. + */ + public static final String KIND_STREAM = "kind:stream"; + + /** + * Encodes the given value using the specified Coder, and returns + * the encoded bytes. + * + * @throws CoderException if there are errors during encoding + */ + public static byte[] encodeToByteArray(Coder coder, T value) + throws CoderException { + try { + try (ByteArrayOutputStream os = new ByteArrayOutputStream()) { + coder.encode(value, os, Coder.Context.OUTER); + return os.toByteArray(); + } + } catch (IOException exn) { + throw new RuntimeException("unexpected IOException", exn); + } + } + + /** + * Decodes the given bytes using the specified Coder, and returns + * the resulting decoded value. + * + * @throws CoderException if there are errors during decoding + */ + public static T decodeFromByteArray(Coder coder, byte[] encodedValue) + throws CoderException { + try { + try (ByteArrayInputStream is = new ByteArrayInputStream(encodedValue)) { + T result = coder.decode(is, Coder.Context.OUTER); + if (is.available() != 0) { + throw new CoderException( + is.available() + " unexpected extra bytes after decoding " + + result); + } + return result; + } + } catch (IOException exn) { + throw new RuntimeException("unexpected IOException", exn); + } + } + + public static CloudObject makeCloudEncoding( + String type, + CloudObject... componentSpecs) { + CloudObject encoding = CloudObject.forClassName(type); + if (componentSpecs.length > 0) { + addList(encoding, PropertyNames.COMPONENT_ENCODINGS, componentSpecs); + } + return encoding; + } + + /** + * A {@link com.fasterxml.jackson.databind.module.Module} which adds the type + * resolver needed for Coder definitions created by the Dataflow service. + */ + static final class Jackson2Module extends SimpleModule { + /** + * The Coder custom type resolver. + *

+ * This resolver resolves coders. If the Coder ID is a particular + * well-known identifier supplied by the Dataflow service, it's replaced + * with the corresponding class. All other Coder instances are resolved + * by class name, using the package com.google.cloud.dataflow.sdk.coders + * if there are no "."s in the ID. + */ + private static final class Resolver extends TypeIdResolverBase { + public Resolver() { + super(TypeFactory.defaultInstance().constructType(Coder.class), + TypeFactory.defaultInstance()); + } + + @Override + public JavaType typeFromId(String id) { + Class clazz = getClassForId(id); + if (clazz == KvCoder.class) { + clazz = KvCoderBase.class; + } + if (clazz == MapCoder.class) { + clazz = MapCoderBase.class; + } + TypeVariable[] tvs = clazz.getTypeParameters(); + JavaType[] types = new JavaType[tvs.length]; + for (int lupe = 0; lupe < tvs.length; lupe++) { + types[lupe] = TypeFactory.unknownType(); + } + return _typeFactory.constructSimpleType(clazz, types); + } + + private Class getClassForId(String id) { + try { + if (id.contains(".")) { + return Class.forName(id); + } + + if (id.equals(KIND_STREAM)) { + return IterableCoder.class; + } else if (id.equals(KIND_PAIR)) { + return KvCoder.class; + } + + // Otherwise, see if the ID is the name of a class in + // com.google.cloud.dataflow.sdk.coders. We do this via creating + // the class object so that class loaders have a chance to get + // involved -- and since we need the class object anyway. + return Class.forName("com.google.cloud.dataflow.sdk.coders." + id); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Unable to convert coder ID " + id + " to class", e); + } + } + + @Override + public String idFromValueAndType(Object o, Class clazz) { + return clazz.getName(); + } + + @Override + public String idFromValue(Object o) { + return o.getClass().getName(); + } + + @Override + public JsonTypeInfo.Id getMechanism() { + return JsonTypeInfo.Id.CUSTOM; + } + } + + /** + * The mixin class defining how Coders are handled by the deserialization + * {@link ObjectMapper}. + *

+ * This is done via a mixin so that this resolver is only used + * during deserialization requested by the Dataflow SDK. + */ + @JsonTypeIdResolver(Resolver.class) + @JsonTypeInfo(use = Id.CUSTOM, include = As.PROPERTY, property = PropertyNames.OBJECT_TYPE_NAME) + private static final class Mixin {} + + public Jackson2Module() { + super("DataflowCoders"); + setMixInAnnotation(Coder.class, Mixin.class); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Credentials.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Credentials.java new file mode 100644 index 000000000000..2a24a76fde9f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Credentials.java @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.extensions.java6.auth.oauth2.AbstractPromptReceiver; +import com.google.api.client.extensions.java6.auth.oauth2.AuthorizationCodeInstalledApp; +import com.google.api.client.googleapis.auth.oauth2.GoogleAuthorizationCodeFlow; +import com.google.api.client.googleapis.auth.oauth2.GoogleClientSecrets; +import com.google.api.client.googleapis.auth.oauth2.GoogleCredential; +import com.google.api.client.googleapis.auth.oauth2.GoogleOAuthConstants; +import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.util.Preconditions; +import com.google.api.client.util.Strings; +import com.google.api.client.util.store.FileDataStoreFactory; +import com.google.cloud.dataflow.sdk.options.GcpOptions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * Provides support for loading credentials. + */ +public class Credentials { + + private static final Logger LOG = LoggerFactory.getLogger(Credentials.class); + + /** OAuth 2.0 scopes used by a local worker (not on GCE). + * The scope cloud-platform provides access to all Cloud Platform resources. + * cloud-platform isn't sufficient yet for talking to datastore so we request + * those resources separately. + * + * Note that trusted scope relationships don't apply to OAuth tokens, so for + * services we access directly (GCS) as opposed to through the backend + * (BigQuery, GCE), we need to explicitly request that scope. + */ + private static final List WORKER_SCOPES = Arrays.asList( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/devstorage.full_control", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/datastore"); + + private static final List USER_SCOPES = Arrays.asList( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/devstorage.full_control", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/datastore"); + + private static class PromptReceiver extends AbstractPromptReceiver { + @Override + public String getRedirectUri() { + return GoogleOAuthConstants.OOB_REDIRECT_URI; + } + } + + /** + * Initializes OAuth2 credential for a worker, using the + * + * application default credentials, or from a local key file when running outside of GCE. + */ + public static Credential getWorkerCredential(GcpOptions options) + throws IOException { + String keyFile = options.getServiceAccountKeyfile(); + String accountName = options.getServiceAccountName(); + + if (keyFile != null && accountName != null) { + try { + return getCredentialFromFile(keyFile, accountName, WORKER_SCOPES); + } catch (GeneralSecurityException e) { + LOG.warn("Unable to obtain credentials from file {}", keyFile); + // Fall through.. + } + } + + return GoogleCredential.getApplicationDefault().createScoped(WORKER_SCOPES); + } + + /** + * Initializes OAuth2 credential for an interactive user program. + * + * This can use 4 different mechanisms for obtaining a credential: + *

    + *
  1. + * It can fetch the + * + * application default credentials. + *
  2. + *
  3. + * It can run the gcloud tool in a subprocess to obtain a credential. + * This is the preferred mechanism. The property "gcloud_path" can be + * used to specify where we search for gcloud data. + *
  4. + *
  5. + * The user can specify a client secrets file and go through the OAuth2 + * webflow. The credential will then be cached in the user's home + * directory for reuse. Provide the property "secrets_file" to use this + * mechanism. + *
  6. + *
  7. + * The user can specify a file containing a service account. + * Provide the properties "service_account_keyfile" and + * "service_account_name" to use this mechanism. + *
  8. + *
+ * The default mechanism is to use the + * + * application default credentials falling back to gcloud. The other options can be + * used by providing the corresponding properties. + */ + public static Credential getUserCredential(GcpOptions options) + throws IOException, GeneralSecurityException { + String keyFile = options.getServiceAccountKeyfile(); + String accountName = options.getServiceAccountName(); + + if (keyFile != null && accountName != null) { + try { + return getCredentialFromFile(keyFile, accountName, USER_SCOPES); + } catch (GeneralSecurityException e) { + throw new IOException("Unable to obtain credentials from file", e); + } + } + + if (options.getSecretsFile() != null) { + return getCredentialFromClientSecrets(options, USER_SCOPES); + } + + try { + return GoogleCredential.getApplicationDefault().createScoped(USER_SCOPES); + } catch (IOException e) { + LOG.info("Failed to get application default credentials, falling back to gcloud."); + } + + String gcloudPath = options.getGCloudPath(); + return getCredentialFromGCloud(gcloudPath); + } + + /** + * Loads OAuth2 credential from a local file. + */ + private static Credential getCredentialFromFile( + String keyFile, String accountId, Collection scopes) + throws IOException, GeneralSecurityException { + GoogleCredential credential = new GoogleCredential.Builder() + .setTransport(Transport.getTransport()) + .setJsonFactory(Transport.getJsonFactory()) + .setServiceAccountId(accountId) + .setServiceAccountScopes(scopes) + .setServiceAccountPrivateKeyFromP12File(new File(keyFile)) + .build(); + + LOG.info("Created credential from file {}", keyFile); + return credential; + } + + /** + * Loads OAuth2 credential from GCloud utility. + */ + private static Credential getCredentialFromGCloud(String gcloudPath) + throws IOException, GeneralSecurityException { + GCloudCredential credential; + HttpTransport transport = GoogleNetHttpTransport.newTrustedTransport(); + if (Strings.isNullOrEmpty(gcloudPath)) { + credential = new GCloudCredential(transport); + } else { + credential = new GCloudCredential(gcloudPath, transport); + } + + try { + credential.refreshToken(); + } catch (IOException e) { + throw new RuntimeException("Could not obtain credential using gcloud", e); + } + + LOG.info("Got credential from GCloud"); + return credential; + } + + /** + * Loads OAuth2 credential from client secrets, which may require an + * interactive authorization prompt. + */ + private static Credential getCredentialFromClientSecrets( + GcpOptions options, Collection scopes) + throws IOException, GeneralSecurityException { + String clientSecretsFile = options.getSecretsFile(); + + Preconditions.checkArgument(clientSecretsFile != null); + HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport(); + + JsonFactory jsonFactory = JacksonFactory.getDefaultInstance(); + GoogleClientSecrets clientSecrets; + + try { + clientSecrets = GoogleClientSecrets.load(jsonFactory, + new FileReader(clientSecretsFile)); + } catch (IOException e) { + throw new RuntimeException( + "Could not read the client secrets from file: " + clientSecretsFile, + e); + } + + FileDataStoreFactory dataStoreFactory = + new FileDataStoreFactory(new java.io.File(options.getCredentialDir())); + + GoogleAuthorizationCodeFlow flow = new GoogleAuthorizationCodeFlow.Builder( + httpTransport, jsonFactory, clientSecrets, scopes) + .setDataStoreFactory(dataStoreFactory) + .build(); + + // The credentialId identifies the credential if we're using a persistent + // credential store. + Credential credential = + new AuthorizationCodeInstalledApp(flow, new PromptReceiver()) + .authorize(options.getCredentialId()); + + LOG.info("Got credential from client secret"); + return credential; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowReleaseInfo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowReleaseInfo.java new file mode 100644 index 000000000000..ab7e0de6a8e0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowReleaseInfo.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; + +/** + * Utilities for working with the Dataflow distribution. + */ +public final class DataflowReleaseInfo extends GenericJson { + private static final Logger LOG = LoggerFactory.getLogger(DataflowReleaseInfo.class); + + private static final String DATAFLOW_PROPERTIES_PATH = + "/com/google/cloud/dataflow/sdk/sdk.properties"; + + private static class LazyInit { + private static final DataflowReleaseInfo INSTANCE = + new DataflowReleaseInfo(DATAFLOW_PROPERTIES_PATH); + } + + /** + * Returns an instance of DataflowReleaseInfo. + */ + public static DataflowReleaseInfo getReleaseInfo() { + return LazyInit.INSTANCE; + } + + @Key private String name = "Google Cloud Dataflow Java SDK"; + @Key private String version = "Unknown"; + + /** Provides the SDK name. */ + public String getName() { + return name; + } + + /** Provides the SDK version. */ + public String getVersion() { + return version; + } + + private DataflowReleaseInfo(String resourcePath) { + Properties properties = new Properties(); + + InputStream in = DataflowReleaseInfo.class.getResourceAsStream( + DATAFLOW_PROPERTIES_PATH); + if (in == null) { + LOG.warn("Dataflow properties resource not found: {}", resourcePath); + return; + } + + try { + properties.load(in); + } catch (IOException e) { + LOG.warn("Error loading Dataflow properties resource: ", e); + } + + for (String name : properties.stringPropertyNames()) { + if (name.equals("name")) { + // We don't allow the properties to override the SDK name. + continue; + } + put(name, properties.getProperty(name)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectModeExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectModeExecutionContext.java new file mode 100644 index 000000000000..a157ceefa57c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectModeExecutionContext.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.ValueWithMetadata; + +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * {@link ExecutionContext} for use in direct mode. + */ +public class DirectModeExecutionContext extends BatchModeExecutionContext { + List output = new ArrayList<>(); + Map, List> sideOutputs = new HashMap<>(); + + @Override + public ExecutionContext.StepContext createStepContext(String stepName) { + return new StepContext(stepName); + } + + @Override + public void noteOutput(WindowedValue outputElem) { + output.add(ValueWithMetadata.of(outputElem) + .withKey(getKey())); + } + + @Override + public void noteSideOutput(TupleTag tag, WindowedValue outputElem) { + List output = sideOutputs.get(tag); + if (output == null) { + output = new ArrayList<>(); + sideOutputs.put(tag, output); + } + output.add(ValueWithMetadata.of(outputElem) + .withKey(getKey())); + } + + public List> getOutput(TupleTag tag) { + return (List) output; + } + + public List> getSideOutput(TupleTag tag) { + if (sideOutputs.containsKey(tag)) { + return (List) sideOutputs.get(tag); + } else { + return new ArrayList<>(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnContext.java new file mode 100644 index 000000000000..80d8f34edd04 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnContext.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.DoFnRunner.OutputManager; +import com.google.cloud.dataflow.sdk.util.ExecutionContext.StepContext; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A concrete implementation of {@link DoFn.Context} used for running + * a {@link DoFn}. + * + * @param the type of the DoFn's (main) input elements + * @param the type of the DoFn's (main) output elements + * @param the type of object which receives outputs + */ +class DoFnContext extends DoFn.Context { + private static final int MAX_SIDE_OUTPUTS = 1000; + + final PipelineOptions options; + final DoFn fn; + final PTuple sideInputs; + final OutputManager outputManager; + final Map outputMap; + final TupleTag mainOutputTag; + final StepContext stepContext; + final CounterSet.AddCounterMutator addCounterMutator; + + public DoFnContext(PipelineOptions options, + DoFn fn, + PTuple sideInputs, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator) { + fn.super(); + this.options = options; + this.fn = fn; + this.sideInputs = sideInputs; + this.outputManager = outputManager; + this.mainOutputTag = mainOutputTag; + this.outputMap = new HashMap<>(); + outputMap.put(mainOutputTag, outputManager.initialize(mainOutputTag)); + for (TupleTag sideOutputTag : sideOutputTags) { + outputMap.put(sideOutputTag, outputManager.initialize(sideOutputTag)); + } + this.stepContext = stepContext; + this.addCounterMutator = addCounterMutator; + } + + public R getReceiver(TupleTag tag) { + R receiver = outputMap.get(tag); + if (receiver == null) { + throw new IllegalArgumentException( + "calling getReceiver() with unknown tag " + tag); + } + return receiver; + } + + ////////////////////////////////////////////////////////////////////////////// + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public T sideInput(PCollectionView view) { + TupleTag tag = view.getTagInternal(); + if (!sideInputs.has(tag)) { + throw new IllegalArgumentException( + "calling sideInput() with unknown view; " + + "did you forget to pass the view in " + + "ParDo.withSideInputs()?"); + } + return view.fromIterableInternal((Iterable>) sideInputs.get(tag)); + } + + void outputWindowedValue( + O output, + Instant timestamp, + Collection windows) { + WindowedValue windowedElem = WindowedValue.of(output, timestamp, windows); + outputManager.output(outputMap.get(mainOutputTag), windowedElem); + if (stepContext != null) { + stepContext.noteOutput(windowedElem); + } + } + + protected void sideOutputWindowedValue(TupleTag tag, + T output, + Instant timestamp, + Collection windows) { + R receiver = outputMap.get(tag); + if (receiver == null) { + // This tag wasn't declared nor was it seen before during this execution. + // Thus, this must be a new, undeclared and unconsumed output. + + // To prevent likely user errors, enforce the limit on the number of side + // outputs. + if (outputMap.size() >= MAX_SIDE_OUTPUTS) { + throw new IllegalArgumentException( + "the number of side outputs has exceeded a limit of " + + MAX_SIDE_OUTPUTS); + } + + // Register the new TupleTag with outputManager and add an entry for it in + // the outputMap. + receiver = outputManager.initialize(tag); + outputMap.put(tag, receiver); + } + + WindowedValue windowedElem = WindowedValue.of(output, timestamp, windows); + outputManager.output(receiver, windowedElem); + if (stepContext != null) { + stepContext.noteSideOutput(tag, windowedElem); + } + } + + // Following implementations of output, outputWithTimestamp, and sideOutput + // are only accessible in DoFn.startBundle and DoFn.finishBundle, and will be shadowed by + // ProcessContext's versions in DoFn.processElement. + // TODO: it seems wrong to use Long.MIN_VALUE, since it will violate all our rules about + // DoFns preserving watermarks. + @Override + public void output(O output) { + outputWindowedValue(output, + new Instant(Long.MIN_VALUE), + Arrays.asList(GlobalWindow.Window.INSTANCE)); + } + + @Override + public void outputWithTimestamp(O output, Instant timestamp) { + outputWindowedValue(output, timestamp, Arrays.asList(GlobalWindow.Window.INSTANCE)); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + sideOutputWindowedValue(tag, + output, + new Instant(Long.MIN_VALUE), + Arrays.asList(GlobalWindow.Window.INSTANCE)); + } + + private String generateInternalAggregatorName(String userName) { + return "user-" + stepContext.getStepName() + "-" + userName; + } + + @Override + public Aggregator createAggregator( + String name, Combine.CombineFn combiner) { + return new AggregatorImpl<>(generateInternalAggregatorName(name), combiner, addCounterMutator); + } + + @Override + public Aggregator createAggregator( + String name, SerializableFunction, AO> combiner) { + return new AggregatorImpl, AO>( + generateInternalAggregatorName(name), combiner, addCounterMutator); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnProcessContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnProcessContext.java new file mode 100644 index 000000000000..d393e6f0b8b6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnProcessContext.java @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.KeyedState; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresKeyedState; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * A concrete implementation of {@link DoFn.ProcessContext} used for running + * a {@link DoFn} over a single element. + * + * @param the type of the DoFn's (main) input elements + * @param the type of the DoFn's (main) output elements + */ +class DoFnProcessContext extends DoFn.ProcessContext { + + final DoFn fn; + final DoFnContext context; + final WindowedValue windowedValue; + + public DoFnProcessContext(DoFn fn, + DoFnContext context, + WindowedValue windowedValue) { + fn.super(); + this.fn = fn; + this.context = context; + this.windowedValue = windowedValue; + } + + @Override + public PipelineOptions getPipelineOptions() { + return context.getPipelineOptions(); + } + + @Override + public I element() { + return windowedValue.getValue(); + } + + @Override + public KeyedState keyedState() { + if (!(fn instanceof RequiresKeyedState) + || (element() != null && !(element() instanceof KV))) { + throw new UnsupportedOperationException( + "Keyed state is only available in the context of a keyed DoFn marked as requiring state"); + } + + return context.stepContext; + } + + @Override + public T sideInput(PCollectionView view) { + return context.sideInput(view); + } + + @Override + public void output(O output) { + context.outputWindowedValue(output, windowedValue.getTimestamp(), windowedValue.getWindows()); + } + + @Override + public void outputWithTimestamp(O output, Instant timestamp) { + Instant originalTimestamp = windowedValue.getTimestamp(); + + if (originalTimestamp != null) { + Preconditions.checkArgument( + !timestamp.isBefore(originalTimestamp.minus(fn.getAllowedTimestampSkew()))); + } + context.outputWindowedValue(output, timestamp, windowedValue.getWindows()); + } + + void outputWindowedValue( + O output, + Instant timestamp, + Collection windows) { + context.outputWindowedValue(output, timestamp, windows); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + context.sideOutputWindowedValue(tag, + output, + windowedValue.getTimestamp(), + windowedValue.getWindows()); + } + + @Override + public Aggregator createAggregator( + String name, Combine.CombineFn combiner) { + return context.createAggregator(name, combiner); + } + + @Override + public Aggregator createAggregator( + String name, SerializableFunction, AO> combiner) { + return context.createAggregator(name, combiner); + } + + @Override + public Instant timestamp() { + return windowedValue.getTimestamp(); + } + + @Override + public Collection windows() { + return windowedValue.getWindows(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunner.java new file mode 100644 index 000000000000..975af472a4b6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunner.java @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.ExecutionContext.StepContext; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.ArrayList; +import java.util.List; + +/** + * Runs a DoFn by constructing the appropriate contexts and passing them in. + * + * @param the type of the DoFn's (main) input elements + * @param the type of the DoFn's (main) output elements + * @param the type of object which receives outputs + */ +public class DoFnRunner { + + /** Information about how to create output receivers and output to them. */ + public interface OutputManager { + + /** Returns the receiver to use for a given tag. */ + public R initialize(TupleTag tag); + + /** Outputs a single element to the provided receiver. */ + public void output(R receiver, WindowedValue output); + + } + + /** The DoFn being run. */ + public final DoFn fn; + + /** The context used for running the DoFn. */ + public final DoFnContext context; + + private DoFnRunner(PipelineOptions options, + DoFn fn, + PTuple sideInputs, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator) { + this.fn = fn; + this.context = new DoFnContext<>(options, fn, sideInputs, outputManager, + mainOutputTag, sideOutputTags, stepContext, + addCounterMutator); + } + + public static DoFnRunner create( + PipelineOptions options, + DoFn fn, + PTuple sideInputs, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator) { + return new DoFnRunner<>( + options, fn, sideInputs, outputManager, + mainOutputTag, sideOutputTags, stepContext, addCounterMutator); + } + + public static DoFnRunner createWithListOutputs( + PipelineOptions options, + DoFn fn, + PTuple sideInputs, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator) { + return create( + options, fn, sideInputs, + new OutputManager() { + @Override + public List initialize(TupleTag tag) { + return new ArrayList<>(); + } + @Override + public void output(List list, WindowedValue output) { + list.add(output); + } + }, + mainOutputTag, sideOutputTags, stepContext, addCounterMutator); + } + + /** Calls {@link DoFn#startBundle}. */ + public void startBundle() { + // This can contain user code. Wrap it in case it throws an exception. + try { + fn.startBundle(context); + } catch (Throwable t) { + // Exception in user code. + throw new UserCodeException(t); + } + } + + /** + * Calls {@link DoFn#processElement} with a ProcessContext containing + * the current element. + */ + public void processElement(WindowedValue elem) { + DoFnProcessContext processContext = new DoFnProcessContext(fn, context, elem); + + // This can contain user code. Wrap it in case it throws an exception. + try { + fn.processElement(processContext); + } catch (Throwable t) { + // Exception in user code. + throw new UserCodeException(t); + } + } + + /** Calls {@link DoFn#finishBundle}. */ + public void finishBundle() { + // This can contain user code. Wrap it in case it throws an exception. + try { + fn.finishBundle(context); + } catch (Throwable t) { + // Exception in user code. + throw new UserCodeException(t); + } + } + + /** Returns the receiver who gets outputs with the provided tag. */ + public R getReceiver(TupleTag tag) { + return context.getReceiver(tag); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutionContext.java new file mode 100644 index 000000000000..12d0745b67b6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutionContext.java @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.values.CodedTupleTag; +import com.google.cloud.dataflow.sdk.values.CodedTupleTagMap; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Context about the current execution. This is guaranteed to exist during processing, + * but does not necessarily persist between different batches of work. + */ +public abstract class ExecutionContext { + private Map cachedStepContexts = new HashMap<>(); + + /** + * Returns the {@link StepContext} associated with the given step. + */ + public StepContext getStepContext(String stepName) { + StepContext context = cachedStepContexts.get(stepName); + if (context == null) { + context = createStepContext(stepName); + cachedStepContexts.put(stepName, context); + } + return context; + } + + /** + * Returns a collection view of all of the {@link StepContext}s. + */ + public Collection getAllStepContexts() { + return cachedStepContexts.values(); + } + + /** + * Implementations should override this to create the specific type + * of {@link StepContext} they neeed. + */ + public abstract StepContext createStepContext(String stepName); + + /** + * Writes out a timer to be fired when the watermark reaches the given + * timestamp. Timers are identified by their name, and can be moved + * by calling {@code setTimer} again, or deleted with + * {@link ExecutionContext#deleteTimer}. + */ + public abstract void setTimer(String timer, Instant timestamp); + + /** + * Deletes the given timer. + */ + public abstract void deleteTimer(String timer); + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#output} + * is called. + */ + public void noteOutput(WindowedValue output) {} + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#sideOutput} + * is called. + */ + public void noteSideOutput(TupleTag tag, WindowedValue output) {} + + /** + * Per-step, per-key context used for retrieving state. + */ + public abstract class StepContext implements DoFn.KeyedState { + private final String stepName; + + public StepContext(String stepName) { + this.stepName = stepName; + } + + public String getStepName() { + return stepName; + } + + public ExecutionContext getExecutionContext() { + return ExecutionContext.this; + } + + public void noteOutput(WindowedValue output) { + ExecutionContext.this.noteOutput(output); + } + + public void noteSideOutput(TupleTag tag, WindowedValue output) { + ExecutionContext.this.noteSideOutput(tag, output); + } + + /** + * Stores the provided value in per-{@link com.google.cloud.dataflow.sdk.transforms.DoFn}, + * per-key state. This state is in the form of a map from tags to arbitrary + * encodable values. + * + * @throws IOException if encoding the given value fails + */ + public abstract void store(CodedTupleTag tag, T value) throws IOException; + + /** + * Loads the values from the per-{@link com.google.cloud.dataflow.sdk.transforms.DoFn}, + * per-key state corresponding to the given tags. + * + * @throws IOException if decoding any of the requested values fails + */ + public abstract CodedTupleTagMap lookup(List> tags) + throws IOException; + + /** + * Loads the value from the per-{@link com.google.cloud.dataflow.sdk.transforms.DoFn}, + * per-key state corresponding to the given tag. + * + * @throws IOException if decoding the value fails + */ + public T lookup(CodedTupleTag tag) throws IOException { + return lookup(Arrays.asList(tag)).get(tag); + } + + /** + * Writes the provided value to the list of values in stored state corresponding to the + * provided tag. + * + * @throws IOException if encoding the given value fails + */ + public abstract void writeToTagList(CodedTupleTag tag, T value, Instant timestamp) + throws IOException; + + /** + * Deletes the list corresponding to the given tag. + */ + public abstract void deleteTagList(CodedTupleTag tag); + + /** + * Reads the elements of the list in stored state corresponding to the provided tag. + * + * @throws IOException if decoding any of the requested values fails + */ + public abstract Iterable readTagList(CodedTupleTag tag) + throws IOException; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java new file mode 100644 index 000000000000..71f66ed2f6db --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileFilter; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.file.FileSystems; +import java.nio.file.Files; +import java.nio.file.PathMatcher; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * Implements IOChannelFactory for local files. + */ +public class FileIOChannelFactory implements IOChannelFactory { + private static final Logger LOG = LoggerFactory.getLogger(FileIOChannelFactory.class); + + // This implementation only allows for wildcards in the file name. + // The directory portion must exist as-is. + @Override + public Collection match(String spec) throws IOException { + File file = new File(spec); + + File parent = file.getParentFile(); + if (!parent.exists()) { + throw new IOException("Unable to find parent directory of " + spec); + } + + final PathMatcher matcher = + FileSystems.getDefault().getPathMatcher("glob:" + spec); + File[] files = parent.listFiles(new FileFilter() { + @Override + public boolean accept(File pathname) { + return matcher.matches(pathname.toPath()); + } + }); + + List result = new LinkedList<>(); + for (File match : files) { + result.add(match.getPath()); + } + + return result; + } + + @Override + public ReadableByteChannel open(String spec) throws IOException { + LOG.debug("opening file {}", spec); + FileInputStream inputStream = new FileInputStream(spec); + return inputStream.getChannel(); + } + + @Override + public WritableByteChannel create(String spec, String mimeType) + throws IOException { + LOG.debug("creating file {}", spec); + return Channels.newChannel( + new BufferedOutputStream(new FileOutputStream(spec))); + } + + @Override + public long getSizeBytes(String spec) throws IOException { + return Files.size(FileSystems.getDefault().getPath(spec)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GCloudCredential.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GCloudCredential.java new file mode 100644 index 000000000000..a3a3fd2eb5bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GCloudCredential.java @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.BearerToken; +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.auth.oauth2.TokenResponse; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.util.IOUtils; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; + +/** + * A credential object which uses the GCloud command line tool to get + * an access token. + */ +public class GCloudCredential extends Credential { + private static final String DEFAULT_GCLOUD_BINARY = "gcloud"; + private final String binary; + + public GCloudCredential(HttpTransport transport) { + this(DEFAULT_GCLOUD_BINARY, transport); + } + + /** + * Path to the GCloud binary. + */ + public GCloudCredential(String binary, HttpTransport transport) { + super(new Builder(BearerToken.authorizationHeaderAccessMethod()) + .setTransport(transport)); + + this.binary = binary; + } + + private String readStream(InputStream stream) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + IOUtils.copy(stream, baos); + return baos.toString("UTF-8"); + } + + @Override + protected TokenResponse executeRefreshToken() throws IOException { + TokenResponse response = new TokenResponse(); + + ProcessBuilder builder = new ProcessBuilder(); + // ProcessBuilder will search the path automatically for the binary + // GCLOUD_BINARY. + builder.command(Arrays.asList(binary, "auth", "print-access-token")); + Process process = builder.start(); + + try { + process.waitFor(); + } catch (InterruptedException e) { + throw new RuntimeException( + "Could not obtain an access token using gcloud; timed out waiting " + + "for gcloud."); + } + + if (process.exitValue() != 0) { + String output; + try { + output = readStream(process.getErrorStream()); + } catch (IOException e) { + throw new RuntimeException( + "Could not obtain an access token using gcloud."); + } + + throw new RuntimeException( + "Could not obtain an access token using gcloud. Result of " + + "invoking gcloud was:\n" + output); + } + + String output; + try { + output = readStream(process.getInputStream()); + } catch (IOException e) { + throw new RuntimeException( + "Could not obtain an access token using gcloud. We encountered an " + + "an error trying to read stdout.", e); + } + String[] lines = output.split("\n"); + + if (lines.length != 1) { + throw new RuntimeException( + "Could not obtain an access token using gcloud. Result of " + + "invoking gcloud was:\n" + output); + } + + // Access token should be good for 5 minutes. + Long expiresInSeconds = 5L * 60; + response.setExpiresInSeconds(expiresInSeconds); + response.setAccessToken(output.trim()); + + return response; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactory.java new file mode 100644 index 000000000000..9ff133261e60 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactory.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; + +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * Implements IOChannelFactory for GCS. + */ +public class GcsIOChannelFactory implements IOChannelFactory { + + private final GcsOptions options; + + public GcsIOChannelFactory(GcsOptions options) { + this.options = options; + } + + @Override + public Collection match(String spec) throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + List matched = util.expand(path); + + List specs = new LinkedList<>(); + for (GcsPath match : matched) { + specs.add(match.toString()); + } + + return specs; + } + + @Override + public ReadableByteChannel open(String spec) throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + return util.open(path); + } + + @Override + public WritableByteChannel create(String spec, String mimeType) + throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + return util.create(path, mimeType); + } + + @Override + public long getSizeBytes(String spec) throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + return util.fileSize(path); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsUtil.java new file mode 100644 index 000000000000..c3edd2ac2c33 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsUtil.java @@ -0,0 +1,277 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Preconditions; +import com.google.api.services.storage.Storage; +import com.google.api.services.storage.model.Objects; +import com.google.api.services.storage.model.StorageObject; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.util.gcsio.GoogleCloudStorageReadChannel; +import com.google.cloud.dataflow.sdk.util.gcsio.GoogleCloudStorageWriteChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.channels.SeekableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Provides operations on GCS. + * + * TODO: re-implement as a FileSystemProvider? + */ +public class GcsUtil { + /** + * This is a {@link DefaultValueFactory} able to create a {@link GcsUtil} using + * any transport flags specified on the {@link PipelineOptions}. + */ + public static class GcsUtilFactory implements DefaultValueFactory { + /** + * Returns an instance of {@link GcsUtil} based on the + * {@link PipelineOptions}. + *

+ * If no instance has previously been created, one is created and the value + * stored in {@code options}. + */ + @Override + public GcsUtil create(PipelineOptions options) { + GcsOptions gcsOptions = options.as(GcsOptions.class); + LOG.debug("Creating new GcsUtil"); + return new GcsUtil(Transport.newStorageClient(gcsOptions).build(), + gcsOptions.getExecutorService()); + } + } + + private static final Logger LOG = LoggerFactory.getLogger(GcsUtil.class); + + /** Maximum number of items to retrieve per Objects.List request. */ + private static final long MAX_LIST_ITEMS_PER_CALL = 1024; + + /** Matches a glob containing a wildcard, capturing the portion before the first wildcard. */ + private static final Pattern GLOB_PREFIX = Pattern.compile("(?[^*?]*)[*?].*"); + + private static final String WILDCARD = "[\\[\\]*?]"; + private static final String NON_WILDCARD = "[^\\[\\]*?]"; + private static final String NON_DELIMITER = "[^/]"; + private static final String OPTIONAL_WILDCARD_AND_SUFFIX = "(" + WILDCARD + NON_DELIMITER + "*)?"; + + /** + * A {@link Pattern} that matches globs in which every wildcard is interpreted as such, + * assuming a delimiter of {@code '/'}. + * + *

Most importantly, if a {@code '*'} or {@code '?'} occurs before the + * final delimiter it will not be interpreted as a wildcard. + */ + public static final Pattern GCS_READ_PATTERN = Pattern.compile( + NON_WILDCARD + "*" + OPTIONAL_WILDCARD_AND_SUFFIX); + + ///////////////////////////////////////////////////////////////////////////// + + /** Client for the GCS API */ + private final Storage storage; + + // Helper delegate for turning IOExceptions from API calls into higher-level semantics. + private final ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + + // Exposed for testing. + final ExecutorService executorService; + + private GcsUtil(Storage storageClient, ExecutorService executorService) { + storage = storageClient; + this.executorService = executorService; + } + + /** + * Expands a pattern into matched paths. The input path may contain + * globs (in the last component only!), which are expanded in the result. + * + * TODO: add support for full path matching. + */ + public List expand(GcsPath path) throws IOException { + if (!GCS_READ_PATTERN.matcher(path.getObject()).matches()) { + throw new IllegalArgumentException( + "Unsupported wildcard usage in \"" + path + "\": " + + " all wildcards must occur after the final '/' delimiter."); + } + + Matcher m = GLOB_PREFIX.matcher(path.getObject()); + if (!m.matches()) { + return Arrays.asList(path); + } + + String prefix = m.group("PREFIX"); + Pattern p = Pattern.compile(globToRegexp(path.getObject())); + LOG.info("matching files in bucket {}, prefix {} against pattern {}", + path.getBucket(), prefix, p.toString()); + + Storage.Objects.List listObject = storage.objects().list(path.getBucket()); + listObject.setMaxResults(MAX_LIST_ITEMS_PER_CALL); + listObject.setDelimiter("/"); + listObject.setPrefix(prefix); + + String pageToken = null; + List results = new LinkedList<>(); + do { + if (pageToken != null) { + listObject.setPageToken(pageToken); + } + + Objects objects = listObject.execute(); + Preconditions.checkNotNull(objects); + + if (objects.getItems() == null) { + break; + } + + // Filter + for (StorageObject o : objects.getItems()) { + String name = o.getName(); + // Skip directories, which end with a slash. + if (p.matcher(name).matches() && !name.endsWith("/")) { + LOG.debug("Matched object: {}", name); + results.add(GcsPath.fromObject(o)); + } + } + + pageToken = objects.getNextPageToken(); + } while (pageToken != null); + + return results; + } + + /** + * Returns the file size from GCS, or -1 if the file does not exist. + */ + public long fileSize(GcsPath path) throws IOException { + try { + Storage.Objects.Get getObject = + storage.objects().get(path.getBucket(), path.getObject()); + + StorageObject object = getObject.execute(); + return object.getSize().longValue(); + } catch (IOException e) { + if (errorExtractor.itemNotFound(e)) { + return -1; + } + + // Re-throw any other error. + throw e; + } + } + + /** + * Opens an object in GCS. + * + * Returns a SeekableByteChannel which provides access to data in the bucket. + * + * @param path the GCS filename to read from + * @return a SeekableByteChannel which can read the object data + * @throws IOException + */ + public SeekableByteChannel open(GcsPath path) + throws IOException { + return new GoogleCloudStorageReadChannel(storage, path.getBucket(), + path.getObject(), errorExtractor); + } + + /** + * Creates an object in GCS. + * + * Returns a WritableByteChannel which can be used to write data to the + * object. + * + * @param path the GCS file to write to + * @param type the type of object, eg "text/plain". + * @return a Callable object which encloses the operation. + * @throws IOException + */ + public WritableByteChannel create(GcsPath path, + String type) throws IOException { + return new GoogleCloudStorageWriteChannel( + executorService, + storage, + path.getBucket(), + path.getObject(), + type); + } + + /** + * Expands glob expressions to regular expressions. + * + * @param globExp the glob expression to expand + * @return a string with the regular expression this glob expands to + */ + static String globToRegexp(String globExp) { + StringBuilder dst = new StringBuilder(); + char[] src = globExp.toCharArray(); + int i = 0; + while (i < src.length) { + char c = src[i++]; + switch (c) { + case '*': + dst.append("[^/]*"); + break; + case '?': + dst.append("[^/]"); + break; + case '.': + case '+': + case '{': + case '}': + case '(': + case ')': + case '|': + case '^': + case '$': + // These need to be escaped in regular expressions + dst.append('\\').append(c); + break; + case '\\': + i = doubleSlashes(dst, src, i); + break; + default: + dst.append(c); + break; + } + } + return dst.toString(); + } + + private static int doubleSlashes(StringBuilder dst, char[] src, int i) { + // Emit the next character without special interpretation + dst.append('\\'); + if ((i - 1) != src.length) { + dst.append(src[i]); + i++; + } else { + // A backslash at the very end is treated like an escaped backslash + dst.append('\\'); + } + return i; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFn.java new file mode 100644 index 000000000000..62ae4875f965 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFn.java @@ -0,0 +1,359 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.NonMergingWindowingFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.util.common.PeekingReiterator; +import com.google.cloud.dataflow.sdk.util.common.Reiterable; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ListMultimap; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.PriorityQueue; + +/** + * DoFn that merges windows and groups elements in those windows, optionally + * combining values. + * + * @param key type + * @param input value element type + * @param window type + */ +public class GroupAlsoByWindowsDoFn + extends DoFn>>, KV>> { + // TODO: Add back RequiresKeyed state once that is supported. + + protected WindowingFn windowingFn; + protected Coder inputCoder; + + public GroupAlsoByWindowsDoFn( + WindowingFn windowingFn, + Coder inputCoder) { + this.windowingFn = windowingFn; + this.inputCoder = inputCoder; + } + + @Override + public void processElement(ProcessContext processContext) throws Exception { + DoFnProcessContext>>, KV>> context = + (DoFnProcessContext>>, KV>>) processContext; + + if (windowingFn instanceof NonMergingWindowingFn) { + processElementViaIterators(context); + } else { + processElementViaWindowSet(context); + } + } + + private void processElementViaWindowSet( + DoFnProcessContext>>, KV>> context) + throws Exception { + + K key = context.element().getKey(); + BatchActiveWindowManager activeWindowManager = new BatchActiveWindowManager<>(); + AbstractWindowSet, W> windowSet = + new BufferingWindowSet(key, windowingFn, inputCoder, context, activeWindowManager); + + for (WindowedValue e : context.element().getValue()) { + for (BoundedWindow window : e.getWindows()) { + windowSet.put((W) window, e.getValue()); + } + ((WindowingFn) windowingFn) + .mergeWindows(new AbstractWindowSet.WindowMergeContext(windowSet, windowingFn)); + + maybeOutputWindows(activeWindowManager, windowSet, windowingFn, e.getTimestamp()); + } + + maybeOutputWindows(activeWindowManager, windowSet, windowingFn, null); + + windowSet.flush(); + } + + /** + * Outputs any windows that are complete, with their corresponding elemeents. + * If there are potentially complete windows, try merging windows first. + */ + private void maybeOutputWindows( + BatchActiveWindowManager activeWindowManager, + AbstractWindowSet windowSet, + WindowingFn windowingFn, + Instant nextTimestamp) throws Exception { + if (activeWindowManager.hasMoreWindows() + && (nextTimestamp == null + || activeWindowManager.nextTimestamp().isBefore(nextTimestamp))) { + // There is at least one window ready to emit. Merge now in case that window should be merged + // into a not yet completed one. + ((WindowingFn) windowingFn) + .mergeWindows(new AbstractWindowSet.WindowMergeContext(windowSet, windowingFn)); + } + + while (activeWindowManager.hasMoreWindows() + && (nextTimestamp == null + || activeWindowManager.nextTimestamp().isBefore(nextTimestamp))) { + W window = activeWindowManager.getWindow(); + if (windowSet.contains(window)) { + windowSet.markCompleted(window); + } + } + } + + private void processElementViaIterators( + DoFnProcessContext>>, KV>> context) + throws Exception { + K key = context.element().getKey(); + Iterable> value = context.element().getValue(); + PeekingReiterator> iterator; + + if (value instanceof Collection) { + iterator = new PeekingReiterator<>(new ListReiterator>( + new ArrayList>((Collection>) value), 0)); + } else if (value instanceof Reiterable) { + iterator = new PeekingReiterator(((Reiterable>) value).iterator()); + } else { + throw new IllegalArgumentException( + "Input to GroupAlsoByWindowsDoFn must be a Collection or Reiterable"); + } + + // This ListMultimap is a map of window maxTimestamps to the list of active + // windows with that maxTimestamp. + ListMultimap windows = ArrayListMultimap.create(); + + while (iterator.hasNext()) { + WindowedValue e = iterator.peek(); + for (BoundedWindow window : e.getWindows()) { + // If this window is not already in the active set, emit a new WindowReiterable + // corresponding to this window, starting at this element in the input Reiterable. + if (!windows.containsEntry(window.maxTimestamp(), window)) { + // Iterating through the WindowReiterable may advance iterator as an optimization + // for as long as it detects that there are no new windows. + windows.put(window.maxTimestamp(), window); + context.outputWindowedValue( + KV.of(key, (Iterable) new WindowReiterable(iterator, window)), + window.maxTimestamp(), + Arrays.asList((W) window)); + } + } + // Copy the iterator in case the next DoFn cached its version of the iterator instead + // of immediately iterating through it. + // And, only advance the iterator if the consuming operation hasn't done so. + iterator = iterator.copy(); + if (iterator.hasNext() && iterator.peek() == e) { + iterator.next(); + } + + // Remove all windows with maxTimestamp behind the current timestamp. + Iterator windowIterator = windows.keys().iterator(); + while (windowIterator.hasNext() + && windowIterator.next().isBefore(e.getTimestamp())) { + windowIterator.remove(); + } + } + } + + /** + * {@link Reiterable} representing a view of all elements in a base + * {@link Reiterator} that are in a given window. + */ + private static class WindowReiterable implements Reiterable { + private PeekingReiterator> baseIterator; + private BoundedWindow window; + + public WindowReiterable( + PeekingReiterator> baseIterator, BoundedWindow window) { + this.baseIterator = baseIterator; + this.window = window; + } + + @Override + public Reiterator iterator() { + // We don't copy the baseIterator when creating the first WindowReiterator + // so that the WindowReiterator can advance the baseIterator. We have to + // make a copy afterwards so that future calls to iterator() will start + // at the right spot. + Reiterator result = new WindowReiterator(baseIterator, window); + baseIterator = baseIterator.copy(); + return result; + } + + @Override + public String toString() { + StringBuilder result = new StringBuilder(); + result.append("WR{"); + for (V v : this) { + result.append(v.toString()).append(','); + } + result.append("}"); + return result.toString(); + } + } + + /** + * The {@link Reiterator} used by {@link WindowReiterable}. + */ + private static class WindowReiterator implements Reiterator { + private PeekingReiterator> iterator; + private BoundedWindow window; + + public WindowReiterator(PeekingReiterator> iterator, BoundedWindow window) { + this.iterator = iterator; + this.window = window; + } + + @Override + public Reiterator copy() { + return new WindowReiterator(iterator.copy(), window); + } + + @Override + public boolean hasNext() { + skipToValidElement(); + return (iterator.hasNext() && iterator.peek().getWindows().contains(window)); + } + + @Override + public V next() { + skipToValidElement(); + WindowedValue next = iterator.next(); + if (!next.getWindows().contains(window)) { + throw new NoSuchElementException("No next item in window"); + } + return next.getValue(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + /** + * Moves the underlying iterator forward until it either points to the next + * element in the correct window, or is past the end of the window. + */ + private void skipToValidElement() { + while (iterator.hasNext()) { + WindowedValue peek = iterator.peek(); + if (!peek.getTimestamp().isBefore(window.maxTimestamp())) { + // We are past the end of this window, so there can't be any more + // elements in this iterator. + break; + } + if (!(peek.getWindows().size() == 1 && peek.getWindows().contains(window))) { + // We have reached new windows; we need to copy the iterator so we don't + // keep advancing the outer loop in processElement. + iterator = iterator.copy(); + } + if (!peek.getWindows().contains(window)) { + // The next element is not in the right window: skip it. + iterator.next(); + } else { + // The next element is in the right window. + break; + } + } + } + } + + /** + * {@link Reiterator} that wraps a {@link List}. + */ + private static class ListReiterator implements Reiterator { + private List list; + private int index; + + public ListReiterator(List list, int index) { + this.list = list; + this.index = index; + } + + @Override + public T next() { + return list.get(index++); + } + + @Override + public boolean hasNext() { + return index < list.size(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public Reiterator copy() { + return new ListReiterator(list, index); + } + } + + private static class BatchActiveWindowManager + implements AbstractWindowSet.ActiveWindowManager { + // Sort the windows by their end timestamps so that we can efficiently + // ask for the next window that will be completed. + PriorityQueue windows = new PriorityQueue<>(11, new Comparator() { + @Override + public int compare(W w1, W w2) { + return w1.maxTimestamp().compareTo(w2.maxTimestamp()); + } + }); + + @Override + public void addWindow(W window) { + windows.add(window); + } + + @Override + public void removeWindow(W window) { + windows.remove(window); + } + + /** + * Returns whether there are more windows. + */ + public boolean hasMoreWindows() { + return windows.peek() != null; + } + + /** + * Returns the timestamp of the next window + */ + public Instant nextTimestamp() { + return windows.peek().maxTimestamp(); + } + + /** + * Returns and removes the next window. + */ + public W getWindow() { + return windows.poll(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelFactory.java new file mode 100644 index 000000000000..683ca76efa5d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelFactory.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Collection; + +/** + * Defines a factory for working with read and write channels. + * + * Channels provide an abstract API for IO operations. + * + * See FACTORY_MAP = + Collections.synchronizedMap(new HashMap()); + + // Pattern which matches shard placeholders within a shard template. + private static final Pattern SHARD_FORMAT_RE = Pattern.compile("(S+|N+)"); + + /** + * Associates a scheme with an {@link IOChannelFactory}. + * + * The given factory is used to construct read and write channels when + * a URI is provided with the given scheme. + * + * For example, when reading from "gs://bucket/path", the scheme "gs" is + * used to lookup the appropriate factory. + */ + public static void setIOFactory(String scheme, IOChannelFactory factory) { + FACTORY_MAP.put(scheme, factory); + } + + /** + * Registers standard factories globally. This requires {@link PipelineOptions} + * to provide e.g. credentials for GCS. + */ + public static void registerStandardIOFactories(PipelineOptions options) { + setIOFactory("gs", new GcsIOChannelFactory(options.as(GcsOptions.class))); + } + + /** + * Creates a write channel for the given filename. + */ + public static WritableByteChannel create(String filename, String mimeType) + throws IOException { + return getFactory(filename).create(filename, mimeType); + } + + /** + * Creates a write channel for the given file components. + * + *

If numShards is specified, then a ShardingWritableByteChannel is + * returned. + * + *

Shard numbers are 0 based, meaning they start with 0 and end at the + * number of shards - 1. + */ + public static WritableByteChannel create(String prefix, String shardTemplate, + String suffix, int numShards, String mimeType) throws IOException { + if (numShards == 1) { + return create(constructName(prefix, shardTemplate, suffix, 0, 1), + mimeType); + } + + ShardingWritableByteChannel shardingChannel = + new ShardingWritableByteChannel(); + + Set outputNames = new HashSet<>(); + for (int i = 0; i < numShards; i++) { + String outputName = + constructName(prefix, shardTemplate, suffix, i, numShards); + if (!outputNames.add(outputName)) { + throw new IllegalArgumentException( + "Shard name collision detected for: " + outputName); + } + WritableByteChannel channel = create(outputName, mimeType); + shardingChannel.addChannel(channel); + } + + return shardingChannel; + } + + /** + * Constructs a fully qualified name from components. + * + *

The name is built from a prefix, shard template (with shard numbers + * applied), and a suffix. All components are required, but may be empty + * strings. + * + *

Within a shard template, repeating sequences of the letters "S" or "N" + * are replaced with the shard number, or number of shards respectively. The + * numbers are formatted with leading zeros to match the length of the + * repeated sequence of letters. + * + *

For example, if prefix = "output", shardTemplate = "-SSS-of-NNN", and + * suffix = ".txt", with shardNum = 1 and numShards = 100, the following is + * produced: "output-001-of-100.txt". + */ + public static String constructName(String prefix, + String shardTemplate, String suffix, int shardNum, int numShards) { + // Matcher API works with StringBuffer, rather than StringBuilder. + StringBuffer sb = new StringBuffer(); + sb.append(prefix); + + Matcher m = SHARD_FORMAT_RE.matcher(shardTemplate); + while (m.find()) { + boolean isShardNum = (m.group(1).charAt(0) == 'S'); + + char[] zeros = new char[m.end() - m.start()]; + Arrays.fill(zeros, '0'); + DecimalFormat df = new DecimalFormat(String.valueOf(zeros)); + String formatted = df.format(isShardNum + ? shardNum + : numShards); + m.appendReplacement(sb, formatted); + } + m.appendTail(sb); + + sb.append(suffix); + return sb.toString(); + } + + private static final Pattern URI_SCHEME_PATTERN = Pattern.compile( + "(?[a-zA-Z][-a-zA-Z0-9+.]*)://.*"); + + /** + * Returns the IOChannelFactory associated with an input specification. + */ + public static IOChannelFactory getFactory(String spec) throws IOException { + // The spec is almost, but not quite, a URI. In particular, + // the reserved characters '[', ']', and '?' have meanings that differ + // from their use in the URI spec. ('*' is not reserved). + // Here, we just need the scheme, which is so circumscribed as to be + // very easy to extract with a regex. + Matcher matcher = URI_SCHEME_PATTERN.matcher(spec); + + if (!matcher.matches()) { + return new FileIOChannelFactory(); + } + + String scheme = matcher.group("scheme"); + IOChannelFactory ioFactory = FACTORY_MAP.get(scheme); + if (ioFactory != null) { + return ioFactory; + } + + throw new IOException("Unable to find handler for " + spec); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/InstanceBuilder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/InstanceBuilder.java new file mode 100644 index 000000000000..8712855a8622 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/InstanceBuilder.java @@ -0,0 +1,259 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Preconditions; +import com.google.common.reflect.TypeToken; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.LinkedList; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Utility for creating objects dynamically. + * + * @param type type of object returned by this instance builder + */ +public class InstanceBuilder { + + /** + * Create an InstanceBuilder for the given type. + *

+ * The specified type is the type returned by {@link #build}, which is + * typically the common base type or interface of the instance being + * constructed. + */ + public static InstanceBuilder ofType(Class type) { + return new InstanceBuilder<>(type); + } + + /** + * Create an InstanceBuilder for the given type. + *

+ * The specified type is the type returned by {@link #build}, which is + * typically the common base type or interface for the instance to be + * constructed. + *

+ * The TypeToken argument allows specification of generic types. For example, + * a {@code List} return type can be specified as + * {@code ofType(new TypeToken>(){})}. + */ + public static InstanceBuilder ofType(TypeToken token) { + @SuppressWarnings("unchecked") + Class type = (Class) token.getRawType(); + return new InstanceBuilder<>(type); + } + + /** + * Sets the class name to be constructed. + *

+ * If the name is a simple name (ie {@link Class#getSimpleName()}), then + * the package of the return type is added as a prefix. + *

+ * The default class is the return type, specified in {@link #ofType}. + *

+ * Modifies and returns the {@code InstanceBuilder} for chaining. + * + * @throws ClassNotFoundException if no class can be found by the given name + */ + public InstanceBuilder fromClassName(String name) + throws ClassNotFoundException { + Preconditions.checkArgument(factoryClass == null, + "Class name may only be specified once"); + if (name.indexOf('.') == -1) { + name = type.getPackage().getName() + "." + name; + } + + try { + factoryClass = Class.forName(name); + } catch (ClassNotFoundException e) { + throw new ClassNotFoundException( + String.format("Could not find class: %s", name), e); + } + return this; + } + + /** + * Sets the factory class to use for instance construction. + *

+ * Modifies and returns the {@code InstanceBuilder} for chaining. + */ + public InstanceBuilder fromClass(Class factoryClass) { + this.factoryClass = factoryClass; + return this; + } + + /** + * Sets the name of the factory method used to construct the instance. + *

+ * The default, if no factory method was specified, is to look for a class + * constructor. + *

+ * Modifies and returns the {@code InstanceBuilder} for chaining. + */ + public InstanceBuilder fromFactoryMethod(String methodName) { + Preconditions.checkArgument(this.methodName == null, + "Factory method name may only be specified once"); + this.methodName = methodName; + return this; + } + + /** + * Adds an argument to be passed to the factory method. + *

+ * The argument type is used to lookup the factory method. This type may be + * a supertype of the argument value's class. + *

+ * Modifies and returns the {@code InstanceBuilder} for chaining. + */ + public InstanceBuilder withArg(Class argType, A value) { + parameterTypes.add(argType); + arguments.add(value); + return this; + } + + /** + * Creates the instance by calling the factory method with the given + * arguments. + *

+ *

Defaults

+ *
    + *
  • factory class: defaults to the output type class, overridden + * via {@link #fromClassName(String)}. + *
  • factory method: defaults to using a constructor on the factory + * class, overridden via {@link #fromFactoryMethod(String)}. + *
+ * + * @throws RuntimeException if the method does not exist, on type mismatch, + * or if the method cannot be made accessible. + */ + public T build() { + if (factoryClass == null) { + factoryClass = type; + } + + Class[] types = parameterTypes + .toArray(new Class[parameterTypes.size()]); + + // TODO: cache results, to speed repeated type lookups? + if (methodName != null) { + return buildFromMethod(types); + } else { + return buildFromConstructor(types); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Type of object to construct. + */ + private final Class type; + + /** + * Types of parameters for Method lookup. + * + * @see Class#getDeclaredMethod(String, Class[]) + */ + private final List> parameterTypes = new LinkedList<>(); + + /** + * Arguments to factory method {@link Method#invoke(Object, Object...)}. + */ + private final List arguments = new LinkedList<>(); + + /** + * Name of factory method, or null to invoke the constructor. + */ + @Nullable private String methodName; + + /** + * Factory class, or null to instantiate {@code type}. + */ + @Nullable private Class factoryClass; + + private InstanceBuilder(Class type) { + this.type = type; + } + + private T buildFromMethod(Class[] types) { + Preconditions.checkState(factoryClass != null); + Preconditions.checkState(methodName != null); + + try { + Method method = factoryClass.getDeclaredMethod(methodName, types); + + Preconditions.checkState(Modifier.isStatic(method.getModifiers()), + "Factory method must be a static method for " + + factoryClass.getName() + "#" + method.getName() + ); + + Preconditions.checkState(type.isAssignableFrom(method.getReturnType()), + "Return type for " + factoryClass.getName() + "#" + method.getName() + + " must be assignable to " + type.getSimpleName()); + + if (!method.isAccessible()) { + method.setAccessible(true); + } + + Object[] args = arguments.toArray(new Object[arguments.size()]); + return type.cast(method.invoke(null, args)); + + } catch (NoSuchMethodException e) { + throw new RuntimeException("Unable to find factory method " + + factoryClass.getName() + "#" + methodName); + + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Failed to construct instance from " + + "factory method " + factoryClass.getName() + "#" + methodName, e); + } + } + + private T buildFromConstructor(Class[] types) { + Preconditions.checkState(factoryClass != null); + + try { + Constructor constructor = factoryClass.getDeclaredConstructor(types); + + Preconditions.checkState(type.isAssignableFrom(factoryClass), + "Instance type " + factoryClass.getName() + + " must be assignable to " + type.getSimpleName()); + + if (!constructor.isAccessible()) { + constructor.setAccessible(true); + } + + Object[] args = arguments.toArray(new Object[arguments.size()]); + return type.cast(constructor.newInstance(args)); + + } catch (NoSuchMethodException e) { + throw new RuntimeException("Unable to find constructor for " + + factoryClass.getName()); + + } catch (InvocationTargetException | + InstantiationException | + IllegalAccessException e) { + throw new RuntimeException("Failed to construct instance from " + + "constructor " + factoryClass.getName(), e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MimeTypes.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MimeTypes.java new file mode 100644 index 000000000000..3318a150662a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MimeTypes.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** Constants representing various mime types. */ +public class MimeTypes { + public static final String TEXT = "text/plain"; + public static final String BINARY = "application/octet-stream"; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MonitoringUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MonitoringUtil.java new file mode 100644 index 000000000000..89df25c39111 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MonitoringUtil.java @@ -0,0 +1,230 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.Dataflow.V1b3.Projects.Jobs.Messages; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.ListJobMessagesResponse; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A helper class for monitoring jobs submitted to the service. + */ +public final class MonitoringUtil { + private String projectId; + private Messages messagesClient; + + /** Named constants for common values for the job state. */ + public static enum JobState { + UNKNOWN ("JOB_STATE_UNKNOWN", false), + STOPPED ("JOB_STATE_STOPPED", false), + RUNNING ("JOB_STATE_RUNNING", false), + DONE ("JOB_STATE_DONE", true), + FAILED ("JOB_STATE_FAILED", true), + CANCELLED("JOB_STATE_CANCELLED", true); + + private final String stateName; + private final boolean terminal; + + private JobState(String stateName, boolean terminal) { + this.stateName = stateName; + this.terminal = terminal; + } + + public final String getStateName() { + return stateName; + } + + public final boolean isTerminal() { + return terminal; + } + + private static final Map statesByName = + Collections.unmodifiableMap(buildStatesByName()); + + private static Map buildStatesByName() { + Map result = new HashMap<>(); + for (JobState state : JobState.values()) { + result.put(state.getStateName(), state); + } + return result; + } + + public static JobState toState(String stateName) { + @Nullable JobState state = statesByName.get(stateName); + if (state == null) { + state = UNKNOWN; + } + return state; + } + } + + /** + * An interface which can be used for defining callbacks to receive a list + * of JobMessages containing monitoring information. + */ + public interface JobMessagesHandler { + /** Process the rows. */ + void process(List messages); + } + + /** A handler which prints monitoring messages to a stream. */ + public static class PrintHandler implements JobMessagesHandler { + private PrintStream out; + + /** + * Construct the handler. + * + * @param stream The stream to write the messages to. + */ + public PrintHandler(PrintStream stream) { + out = stream; + } + + @Override + public void process(List messages) { + for (JobMessage message : messages) { + StringBuilder sb = new StringBuilder(); + if (message.getMessageText() != null && !message.getMessageText().isEmpty()) { + if (message.getMessageImportance() != null) { + if (message.getMessageImportance().equals("ERROR")) { + sb.append("Error: "); + } else if (message.getMessageImportance().equals("WARNING")) { + sb.append("Warning: "); + } + } + // TODO: Allow filtering out overly detailed messages. + sb.append(message.getMessageText()); + } + if (sb.length() > 0) { + @Nullable Instant time = fromCloudTime(message.getTime()); + if (time == null) { + out.print("UNKNOWN TIMESTAMP: "); + } else { + out.print(time + ": "); + } + out.println(sb.toString()); + } + } + out.flush(); + } + } + + /** Construct a helper for monitoring. */ + public MonitoringUtil(String projectId, Dataflow dataflow) { + this(projectId, dataflow.v1b3().projects().jobs().messages()); + } + + // @VisibleForTesting + MonitoringUtil(String projectId, Messages messagesClient) { + this.projectId = projectId; + this.messagesClient = messagesClient; + } + + /** + * Comparator for sorting rows in increasing order based on timestamp. + */ + public static class TimeStampComparator implements Comparator { + @Override + public int compare(JobMessage o1, JobMessage o2) { + @Nullable Instant t1 = fromCloudTime(o1.getTime()); + if (t1 == null) { + return -1; + } + @Nullable Instant t2 = fromCloudTime(o2.getTime()); + if (t2 == null) { + return 1; + } + return t1.compareTo(t2); + } + } + + /** + * Return job messages sorted in ascending order by timestamp. + * @param jobId The id of the job to get the messages for. + * @param startTimestampMs Return only those messages with a + * timestamp greater than this value. + * @return collection of messages + * @throws IOException + */ + public ArrayList getJobMessages( + String jobId, long startTimestampMs) throws IOException { + Instant startTimestamp = new Instant(startTimestampMs); + ArrayList allMessages = new ArrayList<>(); + String pageToken = null; + while (true) { + Messages.List listRequest = messagesClient.list(projectId, jobId); + if (pageToken != null) { + listRequest.setPageToken(pageToken); + } + ListJobMessagesResponse response = listRequest.execute(); + + if (response == null || response.getJobMessages() == null) { + return allMessages; + } + + for (JobMessage m : response.getJobMessages()) { + @Nullable Instant timestamp = fromCloudTime(m.getTime()); + if (timestamp == null) { + continue; + } + if (timestamp.isAfter(startTimestamp)) { + allMessages.add(m); + } + } + + if (response.getNextPageToken() == null) { + break; + } else { + pageToken = response.getNextPageToken(); + } + } + + Collections.sort(allMessages, new TimeStampComparator()); + return allMessages; + } + + public static String getJobMonitoringPageURL(String projectName, String jobId) { + try { + // Project name is allowed in place of the project id: the user will be redirected to a URL + // that has the project name replaced with project id. + return String.format( + "https://console.developers.google.com/project/%s/dataflow/job/%s", + URLEncoder.encode(projectName, "UTF-8"), + URLEncoder.encode(jobId, "UTF-8")); + } catch (UnsupportedEncodingException e) { + // Should never happen. + throw new AssertionError("UTF-8 encoding is not supported by the environment", e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/OutputReference.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/OutputReference.java new file mode 100644 index 000000000000..eade03d25204 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/OutputReference.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.api.client.util.Preconditions.checkNotNull; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; + +/** + * A representation used by {@link com.google.api.services.dataflow.model.Step}s + * to reference the output of other {@code Step}s. + */ +public final class OutputReference extends GenericJson { + @Key("@type") + public final String type = "OutputReference"; + + @Key("step_name") + private final String stepName; + + @Key("output_name") + private final String outputName; + + public OutputReference(String stepName, String outputName) { + this.stepName = checkNotNull(stepName); + this.outputName = checkNotNull(outputName); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PTuple.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PTuple.java new file mode 100644 index 000000000000..98fe4606807a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PTuple.java @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A {@code PTuple} is an immutable tuple of + * heterogeneously-typed values, "keyed" by {@link TupleTag}s. + * + *

PTuples can be created and accessed like follows: + *

 {@code
+ * String v1 = ...;
+ * Integer v2 = ...;
+ * Iterable v3 = ...;
+ *
+ * // Create TupleTags for each of the values to put in the
+ * // PTuple (the type of the TupleTag enables tracking the
+ * // static type of each of the values in the PTuple):
+ * TupleTag tag1 = new TupleTag<>();
+ * TupleTag tag2 = new TupleTag<>();
+ * TupleTag> tag3 = new TupleTag<>();
+ *
+ * // Create a PTuple with three values:
+ * PTuple povs =
+ *     PTuple.of(tag1, v1)
+ *         .and(tag2, v2)
+ *         .and(tag3, v3);
+ *
+ * // Create an empty PTuple:
+ * Pipeline p = ...;
+ * PTuple povs2 = PTuple.empty(p);
+ *
+ * // Get values out of a PTuple, using the same tags
+ * // that were used to put them in:
+ * Integer vX = povs.get(tag2);
+ * String vY = povs.get(tag1);
+ * Iterable vZ = povs.get(tag3);
+ *
+ * // Get a map of all values in a PTuple:
+ * Map, ?> allVs = povs.getAll();
+ * } 
+ */ +public class PTuple { + /** + * Returns an empty PTuple. + * + *

Longer PTuples can be created by calling + * {@link #and} on the result. + */ + public static PTuple empty() { + return new PTuple(); + } + + /** + * Returns a singleton PTuple containing the given + * value keyed by the given TupleTag. + * + *

Longer PTuples can be created by calling + * {@link #and} on the result. + */ + public static PTuple of(TupleTag tag, V value) { + return empty().and(tag, value); + } + + /** + * Returns a new PTuple that has all the values and + * tags of this PTuple plus the given value and tag. + * + *

The given TupleTag should not already be mapped to a + * value in this PTuple. + */ + public PTuple and(TupleTag tag, V value) { + Map, Object> newMap = new LinkedHashMap, Object>(); + newMap.putAll(valueMap); + newMap.put(tag, value); + return new PTuple(newMap); + } + + /** + * Returns whether this PTuple contains a value with + * the given tag. + */ + public boolean has(TupleTag tag) { + return valueMap.containsKey(tag); + } + + /** + * Returns the value with the given tag in this + * PTuple. Throws IllegalArgumentException if there is no + * such value, i.e., {@code !has(tag)}. + */ + public V get(TupleTag tag) { + if (!has(tag)) { + throw new IllegalArgumentException( + "TupleTag not found in this PTuple"); + } + @SuppressWarnings("unchecked") + V value = (V) valueMap.get(tag); + return value; + } + + /** + * Returns an immutable Map from TupleTag to corresponding + * value, for all the members of this PTuple. + */ + public Map, ?> getAll() { + return valueMap; + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + private final Map, ?> valueMap; + + private PTuple() { + this(new LinkedHashMap()); + } + + private PTuple(Map, ?> valueMap) { + this.valueMap = Collections.unmodifiableMap(valueMap); + } + + /** + * Returns a PTuple with each of the given tags mapping + * to the corresponding value. + * + *

For internal use only. + */ + public static PTuple ofInternal(Map, ?> valueMap) { + return new PTuple(valueMap); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PackageUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PackageUtil.java new file mode 100644 index 000000000000..c108ceb4f157 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PackageUtil.java @@ -0,0 +1,307 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.Sleeper; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.common.collect.TreeTraverser; +import com.google.common.hash.Funnels; +import com.google.common.hash.Hasher; +import com.google.common.hash.Hashing; +import com.google.common.io.ByteStreams; +import com.google.common.io.CountingOutputStream; +import com.google.common.io.Files; + +import com.fasterxml.jackson.core.Base64Variants; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +/** Helper routines for packages. */ +public class PackageUtil { + private static final Logger LOG = LoggerFactory.getLogger(PackageUtil.class); + /** + * The initial interval to use between package staging attempts. + */ + private static final long INITIAL_BACKOFF_INTERVAL_MS = 5000L; + /** + * The maximum number of attempts when staging a file. + */ + private static final int MAX_ATTEMPTS = 5; + + /** + * Creates a DataflowPackage containing information about how a classpath element should be + * staged. + * + * @param classpathElement The local path for the classpath element. + * @param stagingDirectory The base location in GCS for staged classpath elements. + * @param overridePackageName If non-null, use the given value as the package name + * instead of generating one automatically. + * @return The package. + */ + public static DataflowPackage createPackage(String classpathElement, + GcsPath stagingDirectory, String overridePackageName) { + try { + File file = new File(classpathElement); + String contentHash = computeContentHash(file); + + // Drop the directory prefixes, and form the filename + hash + extension. + String uniqueName = getUniqueContentName(file, contentHash); + + GcsPath stagingPath = stagingDirectory.resolve(uniqueName); + + DataflowPackage target = new DataflowPackage(); + target.setName(overridePackageName != null ? overridePackageName : uniqueName); + target.setLocation(stagingPath.toResourceName()); + return target; + } catch (IOException e) { + throw new RuntimeException("Package setup failure for " + classpathElement, e); + } + } + + /** + * Transfers the classpath elements to GCS. + * + * @param gcsUtil GCS utility. + * @param classpathElements The elements to stage onto GCS. + * @param gcsStaging The path on GCS to stage the classpath elements to. + * @return A list of cloud workflow packages, each representing a classpath element. + */ + public static List stageClasspathElementsToGcs( + GcsUtil gcsUtil, + Collection classpathElements, + GcsPath gcsStaging) { + return stageClasspathElementsToGcs(gcsUtil, classpathElements, gcsStaging, Sleeper.DEFAULT); + } + + // Visible for testing. + static List stageClasspathElementsToGcs( + GcsUtil gcsUtil, + Collection classpathElements, + GcsPath gcsStaging, + Sleeper retrySleeper) { + ArrayList packages = new ArrayList<>(); + + if (gcsStaging == null) { + throw new IllegalArgumentException( + "Can't stage classpath elements on GCS because no GCS location has been provided"); + } + + for (String classpathElement : classpathElements) { + String packageName = null; + if (classpathElement.contains("=")) { + String[] components = classpathElement.split("=", 2); + packageName = components[0]; + classpathElement = components[1]; + } + + DataflowPackage workflowPackage = createPackage( + classpathElement, gcsStaging, packageName); + + packages.add(workflowPackage); + GcsPath target = GcsPath.fromResourceName(workflowPackage.getLocation()); + + // TODO: Should we attempt to detect the Mime type rather than + // always using MimeTypes.BINARY? + try { + long remoteLength = gcsUtil.fileSize(target); + if (remoteLength >= 0 && remoteLength == getClasspathElementLength(classpathElement)) { + LOG.info("Skipping classpath element already on gcs: {} at {}", classpathElement, target); + continue; + } + + // Upload file, retrying on failure. + BackOff backoff = new AttemptBoundedExponentialBackOff( + MAX_ATTEMPTS, + INITIAL_BACKOFF_INTERVAL_MS); + while (true) { + try { + LOG.info("Uploading classpath element {} to {}", classpathElement, target); + try (WritableByteChannel writer = gcsUtil.create(target, MimeTypes.BINARY)) { + copyContent(classpathElement, writer); + } + break; + } catch (IOException e) { + if (BackOffUtils.next(retrySleeper, backoff)) { + LOG.warn("Upload attempt failed, will retry staging of classpath: {}", + classpathElement, e); + } else { + // Rethrow last error, to be included as a cause in the catch below. + LOG.error("Upload failed, will NOT retry staging of classpath: {}", + classpathElement, e); + throw e; + } + } + } + } catch (Exception e) { + throw new RuntimeException("Could not stage classpath element: " + classpathElement, e); + } + } + + return packages; + } + + /** + * If classpathElement is a file, then the files length is returned, otherwise the length + * of the copied stream is returned. + * + * @param classpathElement The local path for the classpath element. + * @return The length of the classpathElement. + */ + private static long getClasspathElementLength(String classpathElement) throws IOException { + File file = new File(classpathElement); + if (file.isFile()) { + return file.length(); + } + + CountingOutputStream countingOutputStream = + new CountingOutputStream(ByteStreams.nullOutputStream()); + try (WritableByteChannel channel = Channels.newChannel(countingOutputStream)) { + copyContent(classpathElement, channel); + } + return countingOutputStream.getCount(); + } + + /** + * Returns a unique name for a file with a given content hash. + *

+ * Directory paths are removed. Example: + *

+   * dir="a/b/c/d", contentHash="f000" => d-f000.zip
+   * file="a/b/c/d.txt", contentHash="f000" => d-f000.txt
+   * file="a/b/c/d", contentHash="f000" => d-f000
+   * 
+ */ + static String getUniqueContentName(File classpathElement, String contentHash) { + String fileName = Files.getNameWithoutExtension(classpathElement.getAbsolutePath()); + String fileExtension = Files.getFileExtension(classpathElement.getAbsolutePath()); + if (classpathElement.isDirectory()) { + return fileName + "-" + contentHash + ".zip"; + } else if (fileExtension.isEmpty()) { + return fileName + "-" + contentHash; + } + return fileName + "-" + contentHash + "." + fileExtension; + } + + /** + * Computes a message digest of the file/directory contents, returning a base64 string which is + * suitable for use in URLs. + */ + private static String computeContentHash(File classpathElement) throws IOException { + TreeTraverser files = Files.fileTreeTraverser(); + Hasher hasher = Hashing.md5().newHasher(); + for (File currentFile : files.preOrderTraversal(classpathElement)) { + String relativePath = relativize(currentFile, classpathElement); + hasher.putString(relativePath, StandardCharsets.UTF_8); + if (currentFile.isDirectory()) { + hasher.putLong(-1L); + continue; + } + hasher.putLong(currentFile.length()); + Files.asByteSource(currentFile).copyTo(Funnels.asOutputStream(hasher)); + } + return Base64Variants.MODIFIED_FOR_URL.encode(hasher.hash().asBytes()); + } + + /** + * Copies the contents of the classpathElement to the output channel. + *

+ * If the classpathElement is a directory, a Zip stream is constructed on the fly, + * otherwise the file contents are copied as-is. + *

+ * The output channel is not closed. + */ + private static void copyContent(String classpathElement, WritableByteChannel outputChannel) + throws IOException { + final File classpathElementFile = new File(classpathElement); + if (!classpathElementFile.isDirectory()) { + Files.asByteSource(classpathElementFile).copyTo(Channels.newOutputStream(outputChannel)); + return; + } + + ZipOutputStream zos = new ZipOutputStream(Channels.newOutputStream(outputChannel)); + zipDirectoryRecursive(classpathElementFile, classpathElementFile, zos); + zos.finish(); + } + + /** + * Private helper function for zipping files. This one goes recursively through the input + * directory and all of its subdirectories and adds the single zip entries. + * + * @param file the file or directory to be added to the zip file. + * @param root each file uses the root directory to generate its relative path within the zip. + * @param zos the zipstream to write to. + * @throws IOException the zipping failed, e.g. because the output was not writable. + */ + private static void zipDirectoryRecursive(File file, File root, ZipOutputStream zos) + throws IOException { + final String entryName = relativize(file, root); + if (file.isDirectory()) { + // We are hitting a directory. Start the recursion. + // Add the empty entry if it is a subdirectory and the subdirectory has no children. + // Don't add it otherwise, as this is incompatible with certain implementations of unzip. + if (file.list().length == 0 && !file.equals(root)) { + ZipEntry entry = new ZipEntry(entryName + "/"); + zos.putNextEntry(entry); + } else { + // loop through the directory content, and zip the files + for (File currentFile : file.listFiles()) { + zipDirectoryRecursive(currentFile, root, zos); + } + } + } else { + // Put the next zip-entry into the zipoutputstream. + ZipEntry entry = new ZipEntry(entryName); + zos.putNextEntry(entry); + Files.asByteSource(file).copyTo(zos); + } + } + + /** + * Constructs a relative path between file and root. + *

+ * This function will attempt to use {@link java.nio.file.Path#relativize} and + * will fallback to using {@link java.net.URI#relativize} in AppEngine. + * + * @param file The file for which the relative path is being constructed for. + * @param root The root from which the relative path should be constructed. + * @return The relative path between the file and root. + */ + private static String relativize(File file, File root) { + if (AppEngineEnvironment.IS_APP_ENGINE) { + // AppEngine doesn't allow for java.nio.file.Path to be used so we rely on + // using URIs, but URIs are broken for UNC paths which AppEngine doesn't + // use. See for more details: http://wiki.eclipse.org/Eclipse/UNC_Paths + return root.toURI().relativize(file.toURI()).getPath(); + } + return root.toPath().relativize(file.toPath()).toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PartitionBufferingWindowSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PartitionBufferingWindowSet.java new file mode 100644 index 000000000000..96b2ece5cf98 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PartitionBufferingWindowSet.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.WindowUtils.bufferTag; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.values.CodedTupleTag; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.util.Collection; + +/** + * A WindowSet where each value is placed in exactly one window, + * and windows are never merged, deleted, or flushed early, and the + * WindowSet itself is never exposed to user code, allowing + * a much simpler (and cheaper) implementation. + * + * This WindowSet only works with {@link StreamingGroupAlsoByWindowsDoFn}. + */ +class PartitionBufferingWindowSet + extends AbstractWindowSet, W> { + PartitionBufferingWindowSet( + K key, + WindowingFn windowingFn, + Coder inputCoder, + DoFnProcessContext>> context, + ActiveWindowManager activeWindowManager) { + super(key, windowingFn, inputCoder, context, activeWindowManager); + } + + @Override + public void put(W window, V value) throws Exception { + context.context.stepContext.writeToTagList( + bufferTag(window, windowingFn.windowCoder(), inputCoder), value, context.timestamp()); + // Adds the window even if it is already present, relying on the streaming backend to + // de-deduplicate. + activeWindowManager.addWindow(window); + } + + @Override + public void remove(W window) throws Exception { + CodedTupleTag tag = bufferTag(window, windowingFn.windowCoder(), inputCoder); + context.context.stepContext.deleteTagList(tag); + } + + @Override + public void merge(Collection otherWindows, W newWindow) { + throw new UnsupportedOperationException(); + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean contains(W window) { + throw new UnsupportedOperationException(); + } + + @Override + protected Iterable finalValue(W window) throws Exception { + CodedTupleTag tag = bufferTag(window, windowingFn.windowCoder(), inputCoder); + Iterable result = context.context.stepContext.readTagList(tag); + if (result == null) { + throw new IllegalStateException("finalValue called for non-existent window"); + } + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java new file mode 100644 index 000000000000..85a81cdeff9c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * Constant property names used by the SDK in CloudWorkflow specifications. + */ +public class PropertyNames { + public static final String APPEND_TRAILING_NEWLINES = "append_trailing_newlines"; + public static final String BIGQUERY_CREATE_DISPOSITION = "create_disposition"; + public static final String BIGQUERY_DATASET = "dataset"; + public static final String BIGQUERY_PROJECT = "project"; + public static final String BIGQUERY_SCHEMA = "schema"; + public static final String BIGQUERY_TABLE = "table"; + public static final String BIGQUERY_WRITE_DISPOSITION = "write_disposition"; + public static final String CO_GBK_RESULT_SCHEMA = "co_gbk_result_schema"; + public static final String COMBINE_FN = "combine_fn"; + public static final String COMPONENT_ENCODINGS = "component_encodings"; + public static final String CUSTOM_SOURCE_FORMAT = "custom_source"; + public static final String CUSTOM_SOURCE_STEP_INPUT = "custom_source_step_input"; + public static final String CUSTOM_SOURCE_SPEC = "spec"; + public static final String CUSTOM_SOURCE_METADATA = "metadata"; + public static final String CUSTOM_SOURCE_DOES_NOT_NEED_SPLITTING = "does_not_need_splitting"; + public static final String CUSTOM_SOURCE_PRODUCES_SORTED_KEYS = "produces_sorted_keys"; + public static final String CUSTOM_SOURCE_IS_INFINITE = "is_infinite"; + public static final String CUSTOM_SOURCE_ESTIMATED_SIZE_BYTES = "estimated_size_bytes"; + public static final String ELEMENT = "element"; + public static final String ELEMENTS = "elements"; + public static final String ENCODING = "encoding"; + public static final String END_INDEX = "end_index"; + public static final String END_OFFSET = "end_offset"; + public static final String END_SHUFFLE_POSITION = "end_shuffle_position"; + public static final String ENVIRONMENT_VERSION_JOB_TYPE_KEY = "job_type"; + public static final String ENVIRONMENT_VERSION_MAJOR_KEY = "major"; + public static final String FILENAME = "filename"; + public static final String FILENAME_PREFIX = "filename_prefix"; + public static final String FILENAME_SUFFIX = "filename_suffix"; + public static final String FILEPATTERN = "filepattern"; + public static final String FOOTER = "footer"; + public static final String FORMAT = "format"; + public static final String HEADER = "header"; + public static final String INPUTS = "inputs"; + public static final String INPUT_CODER = "input_coder"; + public static final String IS_GENERATED = "is_generated"; + public static final String IS_PAIR_LIKE = "is_pair_like"; + public static final String IS_STREAM_LIKE = "is_stream_like"; + public static final String IS_WRAPPER = "is_wrapper"; + public static final String NON_PARALLEL_INPUTS = "non_parallel_inputs"; + public static final String NUM_SHARDS = "num_shards"; + public static final String OBJECT_TYPE_NAME = "@type"; + public static final String OUTPUT = "output"; + public static final String OUTPUT_INFO = "output_info"; + public static final String OUTPUT_NAME = "output_name"; + public static final String PARALLEL_INPUT = "parallel_input"; + public static final String PHASE = "phase"; + public static final String PUBSUB_SUBSCRIPTION = "pubsub_subscription"; + public static final String PUBSUB_TOPIC = "pubsub_topic"; + public static final String SCALAR_FIELD_NAME = "value"; + public static final String SERIALIZED_FN = "serialized_fn"; + public static final String SHARD_NAME_TEMPLATE = "shard_template"; + public static final String SHUFFLE_KIND = "shuffle_kind"; + public static final String SHUFFLE_READER_CONFIG = "shuffle_reader_config"; + public static final String SHUFFLE_WRITER_CONFIG = "shuffle_writer_config"; + public static final String START_INDEX = "start_index"; + public static final String START_OFFSET = "start_offset"; + public static final String START_SHUFFLE_POSITION = "start_shuffle_position"; + public static final String STRIP_TRAILING_NEWLINES = "strip_trailing_newlines"; + public static final String TUPLE_TAGS = "tuple_tags"; + public static final String USER_FN = "user_fn"; + public static final String USER_NAME = "user_name"; + public static final String USES_KEYED_STATE = "uses_keyed_state"; + public static final String VALUE = "value"; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializer.java new file mode 100644 index 000000000000..34d40f147079 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializer.java @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.http.HttpBackOffIOExceptionHandler; +import com.google.api.client.http.HttpBackOffUnsuccessfulResponseHandler; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpUnsuccessfulResponseHandler; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Implements a request initializer which adds retry handlers to all + * HttpRequests. + * + * This allows chaining through to another HttpRequestInitializer, since + * clients have exactly one HttpRequestInitializer, and Credential is also + * a required HttpRequestInitializer. + */ +public class RetryHttpRequestInitializer implements HttpRequestInitializer { + + private static final Logger LOG = LoggerFactory.getLogger(RetryHttpRequestInitializer.class); + + /** + * Http response codes that should be silently ignored. + */ + private static final Set IGNORED_RESPONSE_CODES = new HashSet<>( + Arrays.asList(307 /* Redirect, handled by Apiary client */, + 308 /* Resume Incomplete, handled by Apiary client */)); + + /** + * Http response timeout to use for hanging gets. + */ + private static final int HANGING_GET_TIMEOUT_SEC = 80; + + private static class LoggingHttpBackOffIOExceptionHandler + extends HttpBackOffIOExceptionHandler { + public LoggingHttpBackOffIOExceptionHandler(BackOff backOff) { + super(backOff); + } + + @Override + public boolean handleIOException(HttpRequest request, boolean supportsRetry) + throws IOException { + boolean willRetry = super.handleIOException(request, supportsRetry); + if (willRetry) { + LOG.info("Request failed with IOException, will retry: {}", request.getUrl()); + } else { + LOG.info("Request failed with IOException, will NOT retry: {}", request.getUrl()); + } + return willRetry; + } + } + + private static class LoggingHttpBackoffUnsuccessfulResponseHandler + implements HttpUnsuccessfulResponseHandler { + private final HttpBackOffUnsuccessfulResponseHandler handler; + + public LoggingHttpBackoffUnsuccessfulResponseHandler(BackOff backoff, + Sleeper sleeper) { + handler = new HttpBackOffUnsuccessfulResponseHandler(backoff); + handler.setSleeper(sleeper); + handler.setBackOffRequired( + new HttpBackOffUnsuccessfulResponseHandler.BackOffRequired() { + @Override + public boolean isRequired(HttpResponse response) { + int statusCode = response.getStatusCode(); + return (statusCode / 100 == 5) || // 5xx: server error + statusCode == 429; // 429: Too many requests + } + }); + } + + @Override + public boolean handleResponse(HttpRequest request, HttpResponse response, + boolean supportsRetry) throws IOException { + boolean retry = handler.handleResponse(request, response, supportsRetry); + if (retry) { + LOG.info("Request failed with code {} will retry: {}", + response.getStatusCode(), request.getUrl()); + + } else if (!IGNORED_RESPONSE_CODES.contains(response.getStatusCode())) { + LOG.info("Request failed with code {}, will NOT retry: {}", + response.getStatusCode(), request.getUrl()); + } + + return retry; + } + } + + private final HttpRequestInitializer chained; + + private final NanoClock nanoClock; // used for testing + + private final Sleeper sleeper; // used for testing + + /** + * @param chained a downstream HttpRequestInitializer, which will also be + * applied to HttpRequest initialization. May be null. + */ + public RetryHttpRequestInitializer(@Nullable HttpRequestInitializer chained) { + this(chained, NanoClock.SYSTEM, Sleeper.DEFAULT); + } + + public RetryHttpRequestInitializer(@Nullable HttpRequestInitializer chained, + NanoClock nanoClock, Sleeper sleeper) { + this.chained = chained; + this.nanoClock = nanoClock; + this.sleeper = sleeper; + } + + @Override + public void initialize(HttpRequest request) throws IOException { + if (chained != null) { + chained.initialize(request); + } + + // Set a timeout for hanging-gets. + // TODO: Do this exclusively for work requests. + request.setReadTimeout(HANGING_GET_TIMEOUT_SEC * 1000); + + // Back off on retryable http errors. + request.setUnsuccessfulResponseHandler( + // A back-off multiplier of 2 raises the maximum request retrying time + // to approximately 5 minutes (keeping other back-off parameters to + // their default values). + new LoggingHttpBackoffUnsuccessfulResponseHandler( + new ExponentialBackOff.Builder().setNanoClock(nanoClock) + .setMultiplier(2).build(), + sleeper)); + + // Retry immediately on IOExceptions. + LoggingHttpBackOffIOExceptionHandler loggingBackoffHandler = + new LoggingHttpBackOffIOExceptionHandler(BackOff.ZERO_BACKOFF); + request.setIOExceptionHandler(loggingBackoffHandler); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SerializableUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SerializableUtils.java new file mode 100644 index 000000000000..9ee09c8608ab --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SerializableUtils.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.decodeFromByteArray; +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Arrays; + +/** + * Utilities for working with Serializables. + */ +public class SerializableUtils { + /** + * Serializes the argument into an array of bytes, and returns it. + * + * @throws IllegalArgumentException if there are errors when serializing + */ + public static byte[] serializeToByteArray(Serializable value) { + try { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(buffer)) { + oos.writeObject(value); + } + return buffer.toByteArray(); + } catch (IOException exn) { + throw new IllegalArgumentException( + "unable to serialize " + value, + exn); + } + } + + /** + * Deserializes an object from the given array of bytes, e.g., as + * serialized using {@link #serializeToByteArray}, and returns it. + * + * @throws IllegalArgumentException if there are errors when + * deserializing, using the provided description to identify what + * was being deserialized + */ + public static Object deserializeFromByteArray(byte[] encodedValue, + String description) { + try { + try (ObjectInputStream ois = new ObjectInputStream( + new ByteArrayInputStream(encodedValue))) { + return ois.readObject(); + } + } catch (IOException | ClassNotFoundException exn) { + throw new IllegalArgumentException( + "unable to deserialize " + description, + exn); + } + } + + public static T ensureSerializable(T value) { + @SuppressWarnings("unchecked") + T copy = (T) deserializeFromByteArray(serializeToByteArray(value), + value.toString()); + return copy; + } + + /** + * Serializes a Coder and verifies that it can be correctly deserialized. + *

+ * Throws a RuntimeException if serialized Coder cannot be deserialized, or + * if the deserialized instance is not equal to the original. + *

+ * @return the serialized Coder, as a {@link CloudObject} + */ + public static CloudObject ensureSerializable(Coder coder) { + CloudObject cloudObject = coder.asCloudObject(); + + Coder decoded; + try { + decoded = Serializer.deserialize(cloudObject, Coder.class); + } catch (RuntimeException e) { + throw new RuntimeException( + String.format("Unable to deserialize Coder: %s. " + + "Check that a suitable constructor is defined. " + + "See Coder for details.", coder), e + ); + } + Preconditions.checkState(coder.equals(decoded), + String.format("Coder not equal to original after serialization, " + + "indicating that the Coder may not implement serialization " + + "correctly. Before: %s, after: %s, cloud encoding: %s", + coder, decoded, cloudObject)); + + return cloudObject; + } + + /** + * Serializes an arbitrary T with the given Coder and verifies + * that it can be correctly deserialized. + */ + public static T ensureSerializableByCoder( + Coder coder, T value, String errorContext) { + byte[] encodedValue; + try { + encodedValue = encodeToByteArray(coder, value); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. + throw new IllegalArgumentException( + errorContext + ": unable to encode value " + + value + " using " + coder, + exn); + } + try { + return decodeFromByteArray(coder, encodedValue); + } catch (CoderException exn) { + // TODO: Put in better encoded byte array printing: + // use printable chars with escapes instead of codes, and + // truncate if too long. + throw new IllegalArgumentException( + errorContext + ": unable to decode " + Arrays.toString(encodedValue) + + ", encoding of value " + value + ", using " + coder, + exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Serializer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Serializer.java new file mode 100644 index 000000000000..42071ec467ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Serializer.java @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.Module; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Utility for converting objects between Java and Cloud representations. + */ +public final class Serializer { + // Delay initialization of statics until the first call to Serializer. + private static class SingletonHelper { + static final ObjectMapper OBJECT_MAPPER = createObjectMapper(); + static final ObjectMapper TREE_MAPPER = createTreeMapper(); + + /** + * Creates the object mapper which will be used for serializing Google API + * client maps into Jackson trees. + */ + private static ObjectMapper createTreeMapper() { + return new ObjectMapper(); + } + + /** + * Creates the object mapper which will be used for deserializing Jackson + * trees into objects. + */ + private static ObjectMapper createObjectMapper() { + ObjectMapper m = new ObjectMapper(); + // Ignore properties which are not used by the object. + m.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); + + // For parameters of type Object, use the @type property to determine the + // class to instantiate. + // + // TODO: It would be ideal to do this for all non-final classes. The + // problem with using DefaultTyping.NON_FINAL is that it insists on having + // type information in the JSON for classes with useful default + // implementations, such as List. Ideally, we'd combine these defaults + // with available type information if that information's present. + m.enableDefaultTypingAsProperty( + ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT, + PropertyNames.OBJECT_TYPE_NAME); + + m.registerModule(new CoderUtils.Jackson2Module()); + + return m; + } + } + + /** + * Registers a module to use during object deserialization. + */ + public static void registerModule(Module module) { + SingletonHelper.OBJECT_MAPPER.registerModule(module); + } + + /** + * Deserializes an object from a Dataflow structured encoding (represented in + * Java as a map). + *

+ * The standard Dataflow SDK object serialization protocol is based on JSON. + * Data is typically encoded as a JSON object whose fields represent the + * object's data. + *

+ * The actual deserialization is performed by Jackson, which can deserialize + * public fields, use JavaBean setters, or use injection annotations to + * indicate how to construct the object. The {@link ObjectMapper} used is + * configured to use the "@type" field as the name of the class to instantiate + * (supporting polymorphic types), and may be further configured by + * annotations or via {@link #registerModule}. + *

+ * @see + * Jackson Data-Binding + * @see + * Jackson-Annotations + * @param serialized the object in untyped decoded form (i.e. a nested {@link Map}) + * @param clazz the expected object class + */ + public static T deserialize(Map serialized, Class clazz) { + try { + return SingletonHelper.OBJECT_MAPPER.treeToValue( + SingletonHelper.TREE_MAPPER.valueToTree( + deserializeCloudKnownTypes(serialized)), + clazz); + } catch (JsonProcessingException e) { + throw new RuntimeException( + "Unable to deserialize class " + clazz, e); + } + } + + /** + * Recursively walks the supplied map, looking for well-known cloud type + * information (keyed as {@link PropertyNames#OBJECT_TYPE_NAME}, matching a + * URI value from the {@link CloudKnownType} enum. Upon finding this type + * information, it converts it into the correspondingly typed Java value. + */ + private static Object deserializeCloudKnownTypes(Object src) { + if (src instanceof Map) { + Map srcMap = (Map) src; + @Nullable Object value = srcMap.get(PropertyNames.SCALAR_FIELD_NAME); + @Nullable CloudKnownType type = + CloudKnownType.forUri((String) srcMap.get(PropertyNames.OBJECT_TYPE_NAME)); + if (type != null && value != null) { + // It's a value of a well-known cloud type; let the known type handler + // handle the translation. + Object result = type.parse(value, type.defaultClass()); + return result; + } + // Otherwise, it's just an ordinary map. + Map dest = new HashMap<>(srcMap.size()); + for (Map.Entry entry : srcMap.entrySet()) { + dest.put(entry.getKey(), deserializeCloudKnownTypes(entry.getValue())); + } + return dest; + } + if (src instanceof List) { + List srcList = (List) src; + List dest = new ArrayList<>(srcList.size()); + for (Object obj : srcList) { + dest.add(deserializeCloudKnownTypes(obj)); + } + return dest; + } + // Neither a Map nor a List; no translation needed. + return src; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ShardingWritableByteChannel.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ShardingWritableByteChannel.java new file mode 100644 index 000000000000..4a3322b34535 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ShardingWritableByteChannel.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; + +/** + * Implements a WritableByteChannel which may contain multiple output shards. + * + *

This provides {@link #writeToShard}, which takes a shard number for + * writing to a particular shard. + * + *

The channel is considered open if all downstream channels are open, and + * closes all downstream channels when closed. + */ +public class ShardingWritableByteChannel implements WritableByteChannel { + + /** + * Special shard number which causes a write to all shards. + */ + public static final int ALL_SHARDS = -2; + + + private final ArrayList writers = new ArrayList<>(); + + /** + * Returns the number of output shards. + */ + public int getNumShards() { + return writers.size(); + } + + /** + * Adds another shard output channel. + */ + public void addChannel(WritableByteChannel writer) { + writers.add(writer); + } + + /** + * Returns the WritableByteChannel associated with the given shard number. + */ + public WritableByteChannel getChannel(int shardNum) { + return writers.get(shardNum); + } + + /** + * Writes the buffer to the given shard. + * + *

This does not change the current output shard. + * + * @return The total number of bytes written. If the shard number is + * {@link #ALL_SHARDS}, then the total is the sum of each individual shard + * write. + */ + public int writeToShard(int shardNum, ByteBuffer src) throws IOException { + if (shardNum >= 0) { + return writers.get(shardNum).write(src); + } + + switch (shardNum) { + case ALL_SHARDS: + int size = 0; + for (WritableByteChannel writer : writers) { + size += writer.write(src); + } + return size; + + default: + throw new IllegalArgumentException("Illegal shard number: " + shardNum); + } + } + + /** + * Writes a buffer to all shards. + * + *

Same as calling {@code writeToShard(ALL_SHARDS, buf)}. + */ + @Override + public int write(ByteBuffer src) throws IOException { + return writeToShard(ALL_SHARDS, src); + } + + @Override + public boolean isOpen() { + for (WritableByteChannel writer : writers) { + if (!writer.isOpen()) { + return false; + } + } + + return true; + } + + @Override + public void close() throws IOException { + for (WritableByteChannel writer : writers) { + writer.close(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StreamingGroupAlsoByWindowsDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StreamingGroupAlsoByWindowsDoFn.java new file mode 100644 index 000000000000..dcfd58aee92d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StreamingGroupAlsoByWindowsDoFn.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PartitioningWindowingFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.values.KV; + +import java.io.IOException; + +/** + * DoFn that merges windows and groups elements in those windows. + * + * @param key type + * @param input value element type + * @param output value element type + * @param window type + */ +public class StreamingGroupAlsoByWindowsDoFn + extends DoFn>, KV> implements DoFn.RequiresKeyedState { + + protected WindowingFn windowingFn; + protected Coder inputCoder; + + protected StreamingGroupAlsoByWindowsDoFn( + WindowingFn windowingFn, + Coder inputCoder) { + this.windowingFn = windowingFn; + this.inputCoder = inputCoder; + } + + public static + StreamingGroupAlsoByWindowsDoFn create( + WindowingFn windowingFn, + Coder inputCoder) { + return new StreamingGroupAlsoByWindowsDoFn<>(windowingFn, inputCoder); + } + + private AbstractWindowSet createWindowSet( + K key, + DoFnProcessContext> context, + AbstractWindowSet.ActiveWindowManager activeWindowManager) throws Exception { + if (windowingFn instanceof PartitioningWindowingFn) { + return new PartitionBufferingWindowSet( + key, windowingFn, inputCoder, context, activeWindowManager); + } else { + return new BufferingWindowSet(key, windowingFn, inputCoder, context, activeWindowManager); + } + } + + @Override + public void processElement(ProcessContext processContext) throws Exception { + DoFnProcessContext>, KV> context = + (DoFnProcessContext>, KV>) processContext; + if (!context.element().isTimer()) { + KV element = context.element().element(); + K key = element.getKey(); + VI value = element.getValue(); + AbstractWindowSet windowSet = createWindowSet( + key, context, new StreamingActiveWindowManager<>(context, windowingFn.windowCoder())); + + for (BoundedWindow window : context.windows()) { + windowSet.put((W) window, value); + } + + windowSet.flush(); + } else { + TimerOrElement timer = context.element(); + AbstractWindowSet windowSet = createWindowSet( + (K) timer.key(), context, new StreamingActiveWindowManager<>( + context, windowingFn.windowCoder())); + + // Attempt to merge windows before emitting; that may remove the current window under + // consideration. + ((WindowingFn) windowingFn) + .mergeWindows(new AbstractWindowSet.WindowMergeContext(windowSet, windowingFn)); + + W window = WindowUtils.windowFromString(timer.tag(), windowingFn.windowCoder()); + boolean windowExists; + try { + windowExists = windowSet.contains(window); + } catch (UnsupportedOperationException e) { + windowExists = true; + } + if (windowExists) { + windowSet.markCompleted(window); + windowSet.flush(); + } + } + } + + private static class StreamingActiveWindowManager + implements AbstractWindowSet.ActiveWindowManager { + DoFnProcessContext context; + Coder coder; + + StreamingActiveWindowManager( + DoFnProcessContext context, + Coder coder) { + this.context = context; + this.coder = coder; + } + + @Override + public void addWindow(W window) throws IOException { + context.context.stepContext.getExecutionContext().setTimer( + WindowUtils.windowToString(window, coder), window.maxTimestamp()); + } + + @Override + public void removeWindow(W window) throws IOException { + context.context.stepContext.getExecutionContext().deleteTimer( + WindowUtils.windowToString(window, coder)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StringUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StringUtils.java new file mode 100644 index 000000000000..382683c2de3c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StringUtils.java @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.common.base.Joiner; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Utilities for working with JSON and other human-readable string formats. + */ +public class StringUtils { + /** + * Converts the given array of bytes into a legal JSON string. + * + * Uses a simple strategy of converting each byte to a single char, + * except for non-printable chars, non-ASCII chars, and '%', '\', + * and '"', which are encoded as three chars in '%xx' format, where + * 'xx' is the hexadecimal encoding of the byte. + */ + public static String byteArrayToJsonString(byte[] bytes) { + StringBuilder sb = new StringBuilder(bytes.length * 2); + for (byte b : bytes) { + if (b >= 32 && b < 127) { + // A printable ascii character. + char c = (char) b; + if (c != '%' && c != '\\' && c != '\"') { + // Not an escape prefix or special character, either. + // Send through unchanged. + sb.append(c); + continue; + } + } + // Send through escaped. Use '%xx' format. + sb.append(String.format("%%%02x", b)); + } + return sb.toString(); + } + + /** + * Converts the given string, encoded using {@link #byteArrayToJsonString}, + * into a byte array. + * + * @throws IllegalArgumentException if the argument string is not legal + */ + public static byte[] jsonStringToByteArray(String string) { + List bytes = new ArrayList<>(); + for (int i = 0; i < string.length(); ) { + char c = string.charAt(i); + Byte b; + if (c == '%') { + // Escaped. Expect '%xx' format. + try { + b = (byte) Integer.parseInt(string.substring(i + 1, i + 3), 16); + } catch (IndexOutOfBoundsException | NumberFormatException exn) { + throw new IllegalArgumentException( + "not in legal encoded format; " + + "substring [" + i + ".." + (i + 2) + "] not in format \"%xx\"", + exn); + } + i += 3; + } else { + // Send through unchanged. + b = (byte) c; + i++; + } + bytes.add(b); + } + byte[] byteArray = new byte[bytes.size()]; + int i = 0; + for (Byte b : bytes) { + byteArray[i++] = b; + } + return byteArray; + } + + private static final String[] STANDARD_NAME_SUFFIXES = + new String[]{"DoFn", "Fn"}; + + /** + * Pattern to match a non-anonymous inner class. + * Eg, matches "Foo$Bar", or even "Foo$1$Bar", but not "Foo$1" or "Foo$1$2". + */ + private static final Pattern NAMED_INNER_CLASS = + Pattern.compile(".+\\$(?[^0-9].*)"); + + /** + * Returns a simple name for a class. + * + *

Note: this is non-invertible - the name may be simplified to an + * extent that it cannot be mapped back to the original class. + * + *

This can be used to generate human-readable transform names. It + * removes the package from the name, and removes common suffixes. + * + *

Examples: + *

    + *
  • {@code some.package.WordSummaryDoFn} -> "WordSummary" + *
  • {@code another.package.PairingFn} -> "Pairing" + *
+ */ + public static String approximateSimpleName(Class clazz) { + String fullName = clazz.getName(); + String shortName = fullName.substring(fullName.lastIndexOf('.') + 1); + + // Simplify inner class name by dropping outer class prefixes. + Matcher m = NAMED_INNER_CLASS.matcher(shortName); + if (m.matches()) { + shortName = m.group("INNER"); + } + + // Drop common suffixes for each named component. + String[] names = shortName.split("\\$"); + for (int i = 0; i < names.length; i++) { + names[i] = simplifyNameComponent(names[i]); + } + + return Joiner.on('$').join(names); + } + + private static String simplifyNameComponent(String name) { + for (String suffix : STANDARD_NAME_SUFFIXES) { + if (name.endsWith(suffix) && name.length() > suffix.length()) { + return name.substring(0, name.length() - suffix.length()); + } + } + return name; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Structs.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Structs.java new file mode 100644 index 000000000000..8fb2e834f19e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Structs.java @@ -0,0 +1,345 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Data; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A collection of static methods for manipulating datastructure representations + * transferred via the Dataflow API. + */ +public final class Structs { + private Structs() {} // Non-instantiable + + public static String getString(Map map, String name) throws Exception { + return getValue(map, name, String.class, "a string"); + } + + public static String getString( + Map map, String name, @Nullable String defaultValue) + throws Exception { + return getValue(map, name, String.class, "a string", defaultValue); + } + + public static byte[] getBytes(Map map, String name) throws Exception { + @Nullable byte[] result = getBytes(map, name, null); + if (result == null) { + throw new ParameterNotFoundException(name, map); + } + return result; + } + + @Nullable + public static byte[] getBytes(Map map, String name, @Nullable byte[] defaultValue) + throws Exception { + @Nullable String jsonString = getString(map, name, null); + if (jsonString == null) { + return defaultValue; + } + // TODO: Need to agree on a format for encoding bytes in + // a string that can be sent over the Apiary wire, over the cloud + // map task work API. base64 encoding seems pretty common. Switch to it? + return StringUtils.jsonStringToByteArray(jsonString); + } + + public static Boolean getBoolean(Map map, String name) throws Exception { + return getValue(map, name, Boolean.class, "a boolean"); + } + + @Nullable + public static Boolean getBoolean( + Map map, String name, @Nullable Boolean defaultValue) + throws Exception { + return getValue(map, name, Boolean.class, "a boolean", defaultValue); + } + + public static Long getLong(Map map, String name) throws Exception { + return getValue(map, name, Long.class, "an int"); + } + + @Nullable + public static Long getLong(Map map, String name, @Nullable Long defaultValue) + throws Exception { + return getValue(map, name, Long.class, "an int", defaultValue); + } + + @Nullable + public static List getStrings( + Map map, String name, @Nullable List defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "a string or a list"); + } + return defaultValue; + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a list of strings, + // this is an empty list. + return Collections.emptyList(); + } + @Nullable String singletonString = decodeValue(value, String.class); + if (singletonString != null) { + return Collections.singletonList(singletonString); + } + if (!(value instanceof List)) { + throw new IncorrectTypeException(name, map, "a string or a list"); + } + @SuppressWarnings("unchecked") + List elements = (List) value; + List result = new ArrayList<>(elements.size()); + for (Object o : elements) { + @Nullable String s = decodeValue(o, String.class); + if (s == null) { + throw new IncorrectTypeException(name, map, "a list of strings"); + } + result.add(s); + } + return result; + } + + public static Map getObject(Map map, String name) + throws Exception { + @Nullable Map result = getObject(map, name, null); + if (result == null) { + throw new ParameterNotFoundException(name, map); + } + return result; + } + + @Nullable + public static Map getObject( + Map map, String name, @Nullable Map defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "an object"); + } + return defaultValue; + } + return checkObject(value, map, name); + } + + private static Map checkObject( + Object value, Map map, String name) throws Exception { + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as an object, this is an + // empty map. + return Collections.emptyMap(); + } + if (!(value instanceof Map)) { + throw new IncorrectTypeException(name, map, "an object (not a map)"); + } + @SuppressWarnings("unchecked") + Map mapValue = (Map) value; + if (!mapValue.containsKey(PropertyNames.OBJECT_TYPE_NAME)) { + throw new IncorrectTypeException(name, map, + "an object (no \"" + PropertyNames.OBJECT_TYPE_NAME + "\" field)"); + } + return mapValue; + } + + public static Map getDictionary( + Map map, String name) throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + throw new ParameterNotFoundException(name, map); + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a dictionary, this is + // an empty map. + return Collections.emptyMap(); + } + if (!(value instanceof Map)) { + throw new IncorrectTypeException(name, map, "a dictionary"); + } + @SuppressWarnings("unchecked") + Map result = (Map) value; + return result; + } + + @Nullable + public static Map getDictionary( + Map map, String name, @Nullable Map defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "a dictionary"); + } + return defaultValue; + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a dictionary, this is + // an empty map. + return Collections.emptyMap(); + } + if (!(value instanceof Map)) { + throw new IncorrectTypeException(name, map, "a dictionary"); + } + @SuppressWarnings("unchecked") + Map result = (Map) value; + return result; + } + + // Builder operations. + + public static void addString(Map map, String name, String value) { + addObject(map, name, CloudObject.forString(value)); + } + + public static void addBoolean(Map map, String name, boolean value) { + addObject(map, name, CloudObject.forBoolean(value)); + } + + public static void addLong(Map map, String name, long value) { + addObject(map, name, CloudObject.forInteger(value)); + } + + public static void addObject( + Map map, String name, Map value) { + map.put(name, value); + } + + public static void addNull(Map map, String name) { + map.put(name, Data.nullOf(Object.class)); + } + + public static void addLongs(Map map, String name, long... longs) { + List> elements = new ArrayList<>(longs.length); + for (Long value : longs) { + elements.add(CloudObject.forInteger(value)); + } + map.put(name, elements); + } + + public static void addList( + Map map, String name, List> elements) { + map.put(name, elements); + } + + public static void addStringList(Map map, String name, List elements) { + ArrayList objects = new ArrayList<>(elements.size()); + for (String element : elements) { + objects.add(CloudObject.forString(element)); + } + addList(map, name, objects); + } + + public static > void addList( + Map map, String name, T[] elements) { + map.put(name, Arrays.asList(elements)); + } + + public static void addDictionary( + Map map, String name, Map value) { + map.put(name, value); + } + + public static void addDouble(Map map, String name, Double value) { + addObject(map, name, CloudObject.forFloat(value)); + } + + // Helper methods for a few of the accessor methods. + + private static T getValue(Map map, String name, Class clazz, String type) + throws Exception { + @Nullable T result = getValue(map, name, clazz, type, null); + if (result == null) { + throw new ParameterNotFoundException(name, map); + } + return result; + } + + @Nullable + private static T getValue( + Map map, String name, Class clazz, String type, @Nullable T defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, type); + } + return defaultValue; + } + T result = decodeValue(value, clazz); + if (result == null) { + // The value exists, but can't be decoded. + throw new IncorrectTypeException(name, map, type); + } + return result; + } + + @Nullable + private static T decodeValue(Object value, Class clazz) { + try { + if (value.getClass() == clazz) { + // decodeValue() is only called for final classes; if the class matches, + // it's safe to just return the value, and if it doesn't match, decoding + // is needed. + return clazz.cast(value); + } + if (!(value instanceof Map)) { + return null; + } + @SuppressWarnings("unchecked") + Map map = (Map) value; + @Nullable String typeName = (String) map.get(PropertyNames.OBJECT_TYPE_NAME); + if (typeName == null) { + return null; + } + @Nullable CloudKnownType knownType = CloudKnownType.forUri(typeName); + if (knownType == null) { + return null; + } + @Nullable Object scalar = map.get(PropertyNames.SCALAR_FIELD_NAME); + if (scalar == null) { + return null; + } + return knownType.parse(scalar, clazz); + } catch (ClassCastException e) { + // If any class cast fails during decoding, the value's not decodable. + return null; + } + } + + private static final class ParameterNotFoundException extends Exception { + private static final long serialVersionUID = 0; + + public ParameterNotFoundException(String name, Map map) { + super("didn't find required parameter " + name + " in " + map); + } + } + + private static final class IncorrectTypeException extends Exception { + private static final long serialVersionUID = 0; + + public IncorrectTypeException(String name, Map map, String type) { + super("required parameter " + name + " in " + map + " not " + type); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TestCredential.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TestCredential.java new file mode 100644 index 000000000000..fa02a6bf3185 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TestCredential.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.BearerToken; +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.auth.oauth2.TokenResponse; + +import java.io.IOException; + +/** + * Fake credential, for use in testing. + */ +public class TestCredential extends Credential { + + private final String token; + + public TestCredential() { + this("NULL"); + } + + public TestCredential(String token) { + super(new Builder( + BearerToken.authorizationHeaderAccessMethod())); + this.token = token; + } + + @Override + protected TokenResponse executeRefreshToken() throws IOException { + TokenResponse response = new TokenResponse(); + response.setExpiresInSeconds(5L * 60); + response.setAccessToken(token); + return response; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeUtil.java new file mode 100644 index 000000000000..48324818ca63 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeUtil.java @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import org.joda.time.DateTime; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.ReadableDuration; +import org.joda.time.ReadableInstant; +import org.joda.time.chrono.ISOChronology; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * A helper class for converting between Dataflow API and SDK time + * representations. + *

+ * Dataflow API times are strings of the form + * {@code YYYY-MM-dd'T'HH:mm:ss[.nnnn]'Z'}: that is, RFC 3339 + * strings with optional fractional seconds and a 'Z' offset. + *

+ * Dataflow API durations are strings of the form {@code ['-']sssss[.nnnn]'s'}: + * that is, seconds with optional fractional seconds and a literal 's' at the end. + *

+ * In both formats, fractional seconds are either three digits (millisecond + * resolution), six digits (microsecond resolution), or nine digits (nanosecond + * resolution). + */ +public final class TimeUtil { + private TimeUtil() {} // Non-instantiable. + + private static final Pattern DURATION_PATTERN = Pattern.compile("(\\d+)(?:\\.(\\d+))?s"); + private static final Pattern TIME_PATTERN = + Pattern.compile("(\\d{4})-(\\d{2})-(\\d{2})T(\\d{2}):(\\d{2}):(\\d{2})(?:\\.(\\d+))?Z"); + + /** + * Converts a {@link ReadableInstant} into a Dateflow API time value. + */ + public static String toCloudTime(ReadableInstant instant) { + // Note that since Joda objects use millisecond resolution, we always + // produce either no fractional seconds or fractional seconds with + // millisecond resolution. + + // Translate the ReadableInstant to a DateTime with ISOChronology. + DateTime time = new DateTime(instant); + + int millis = time.getMillisOfSecond(); + if (millis == 0) { + return String.format("%04d-%02d-%02dT%02d:%02d:%02dZ", + time.getYear(), + time.getMonthOfYear(), + time.getDayOfMonth(), + time.getHourOfDay(), + time.getMinuteOfHour(), + time.getSecondOfMinute()); + } else { + return String.format("%04d-%02d-%02dT%02d:%02d:%02d.%03dZ", + time.getYear(), + time.getMonthOfYear(), + time.getDayOfMonth(), + time.getHourOfDay(), + time.getMinuteOfHour(), + time.getSecondOfMinute(), + millis); + } + } + + /** + * Converts a time value received via the Dataflow API into the corresponding + * {@link Instant}. + * @return the parsed time, or null if a parse error occurs + */ + @Nullable + public static Instant fromCloudTime(String time) { + Matcher matcher = TIME_PATTERN.matcher(time); + if (!matcher.matches()) { + return null; + } + int year = Integer.valueOf(matcher.group(1)); + int month = Integer.valueOf(matcher.group(2)); + int day = Integer.valueOf(matcher.group(3)); + int hour = Integer.valueOf(matcher.group(4)); + int minute = Integer.valueOf(matcher.group(5)); + int second = Integer.valueOf(matcher.group(6)); + int millis = 0; + + String frac = matcher.group(7); + if (frac != null) { + int fracs = Integer.valueOf(frac); + if (frac.length() == 3) { // millisecond resolution + millis = fracs; + } else if (frac.length() == 6) { // microsecond resolution + millis = fracs / 1000; + } else if (frac.length() == 9) { // nanosecond resolution + millis = fracs / 1000000; + } else { + return null; + } + } + + return new DateTime(year, month, day, hour, minute, second, millis, + ISOChronology.getInstanceUTC()).toInstant(); + } + + /** + * Converts a {@link ReadableDuration} into a Dataflow API duration string. + */ + public static String toCloudDuration(ReadableDuration duration) { + // Note that since Joda objects use millisecond resolution, we always + // produce either no fractional seconds or fractional seconds with + // millisecond resolution. + long millis = duration.getMillis(); + long seconds = millis / 1000; + millis = millis % 1000; + if (millis == 0) { + return String.format("%ds", seconds); + } else { + return String.format("%d.%03ds", seconds, millis); + } + } + + /** + * Converts a Dataflow API duration string into a {@link Duration}. + * @return the parsed duration, or null if a parse error occurs + */ + @Nullable + public static Duration fromCloudDuration(String duration) { + Matcher matcher = DURATION_PATTERN.matcher(duration); + if (!matcher.matches()) { + return null; + } + long millis = Long.valueOf(matcher.group(1)) * 1000; + String frac = matcher.group(2); + if (frac != null) { + long fracs = Long.valueOf(frac); + if (frac.length() == 3) { // millisecond resolution + millis += fracs; + } else if (frac.length() == 6) { // microsecond resolution + millis += fracs / 1000; + } else if (frac.length() == 9) { // nanosecond resolution + millis += fracs / 1000000; + } else { + return null; + } + } + return Duration.millis(millis); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimerOrElement.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimerOrElement.java new file mode 100644 index 000000000000..4859f8ae5f39 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimerOrElement.java @@ -0,0 +1,195 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Instant; + +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; + +/** + * Class representing either a timer, or arbitrary element. + * Used as the input type of {@link StreamingGroupAlsoByWindowsDoFn}. + * + * @param the element type + */ +public class TimerOrElement { + + /** + * Creates a new {@code TimerOrElement} representing a timer. + * + * @param the element type + */ + public static TimerOrElement timer( + String tag, Instant timestamp, Object key) { + return new TimerOrElement<>(tag, timestamp, key); + } + + /** + * Creates a new {@code TimerOrElement} representing an element. + * + * @param the element type + */ + public static TimerOrElement element(E element) { + return new TimerOrElement<>(element); + } + + /** + * Returns whether this is a timer or an element. + */ + public boolean isTimer() { + return isTimer; + } + + /** + * If this is a timer, returns its tag, otherwise throws an exception. + */ + public String tag() { + if (!isTimer) { + throw new IllegalStateException("tag() called, but this is an element"); + } + return tag; + } + + /** + * If this is a timer, returns its timestamp, otherwise throws an exception. + */ + public Instant timestamp() { + if (!isTimer) { + throw new IllegalStateException("timestamp() called, but this is an element"); + } + return timestamp; + } + + /** + * If this is a timer, returns its key, otherwise throws an exception. + */ + public Object key() { + if (!isTimer) { + throw new IllegalStateException("key() called, but this is an element"); + } + return key; + } + + /** + * If this is an element, returns it, otherwise throws an exception. + */ + public E element() { + if (isTimer) { + throw new IllegalStateException("element() called, but this is a timer"); + } + return element; + } + + /** + * Coder that forwards {@code ByteSizeObserver} calls to an underlying element coder. + * {@code TimerOrElement} objects never need to be encoded, so this class does not + * support the {@code encode} and {@code decode} methods. + */ + public static class TimerOrElementCoder extends StandardCoder> { + final Coder elemCoder; + + /** + * Creates a new {@code TimerOrElement.Coder} that wraps the given {@link Coder}. + */ + public static TimerOrElementCoder of(Coder elemCoder) { + return new TimerOrElementCoder<>(elemCoder); + } + + @JsonCreator + public static TimerOrElementCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + return of((Coder) components.get(0)); + } + + @Override + public void encode(TimerOrElement value, OutputStream outStream, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public TimerOrElement decode(InputStream inStream, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRegisterByteSizeObserverCheap(TimerOrElement value, Context context) { + if (value.isTimer()) { + return true; + } else { + return elemCoder.isRegisterByteSizeObserverCheap(value.element(), context); + } + } + + @Override + public void registerByteSizeObserver( + TimerOrElement value, ElementByteSizeObserver observer, Context context) + throws Exception{ + if (!value.isTimer()) { + elemCoder.registerByteSizeObserver(value.element(), observer, context); + } + } + + @Override + public boolean isDeterministic() { + return elemCoder.isDeterministic(); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(elemCoder); + } + + public Coder getElementCoder() { + return elemCoder; + } + + private TimerOrElementCoder(Coder elemCoder) { + this.elemCoder = elemCoder; + } + } + + ////////////////////////////////////////////////////////////////////////////// + + private boolean isTimer; + private String tag; + private Instant timestamp; + private Object key; + private E element; + + TimerOrElement(String tag, Instant timestamp, Object key) { + this.isTimer = true; + this.tag = tag; + this.timestamp = timestamp; + this.key = key; + } + + TimerOrElement(E element) { + this.isTimer = false; + this.element = element; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java new file mode 100644 index 000000000000..e27f7fcc4f88 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java @@ -0,0 +1,141 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.storage.Storage; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineDebugOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.StreamingOptions; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.security.GeneralSecurityException; + +/** + * Helpers for cloud communication. + */ +public class Transport { + + private static class SingletonHelper { + /** Global instance of the JSON factory. */ + private static final JsonFactory JSON_FACTORY; + + /** Global instance of the HTTP transport. */ + private static final HttpTransport HTTP_TRANSPORT; + + static { + try { + JSON_FACTORY = JacksonFactory.getDefaultInstance(); + HTTP_TRANSPORT = GoogleNetHttpTransport.newTrustedTransport(); + } catch (GeneralSecurityException | IOException e) { + throw new RuntimeException(e); + } + } + } + + public static HttpTransport getTransport() { + return SingletonHelper.HTTP_TRANSPORT; + } + + public static JsonFactory getJsonFactory() { + return SingletonHelper.JSON_FACTORY; + } + + /** + * Returns a BigQuery client builder. + *

+ * Note: this client's endpoint is not modified by the + * {@link DataflowPipelineDebugOptions#getApiRootUrl()} option. + */ + public static Bigquery.Builder + newBigQueryClient(BigQueryOptions options) { + return new Bigquery.Builder(getTransport(), getJsonFactory(), + new RetryHttpRequestInitializer(options.getGcpCredential())) + .setApplicationName(options.getAppName()); + } + +/** + * Returns a Pubsub client builder. + *

+ * Note: this client's endpoint is not modified by the + * {@link DataflowPipelineDebugOptions#getApiRootUrl()} option. + */ + public static Pubsub.Builder + newPubsubClient(StreamingOptions options) { + return new Pubsub.Builder(getTransport(), getJsonFactory(), + new RetryHttpRequestInitializer(options.getGcpCredential())) + .setApplicationName(options.getAppName()); + } + + /** + * Returns a Google Cloud Dataflow client builder. + */ + public static Dataflow.Builder newDataflowClient(DataflowPipelineOptions options) { + String rootUrl = options.getApiRootUrl(); + String servicePath = options.getDataflowEndpoint(); + if (servicePath.contains("://")) { + try { + URL url = new URL(servicePath); + rootUrl = url.getProtocol() + "://" + url.getHost() + + (url.getPort() > 0 ? ":" + url.getPort() : ""); + servicePath = url.getPath(); + } catch (MalformedURLException e) { + throw new RuntimeException("Invalid URL: " + servicePath); + } + } + + return new Dataflow.Builder(getTransport(), + getJsonFactory(), + new RetryHttpRequestInitializer(options.getGcpCredential())) + .setApplicationName(options.getAppName()) + .setRootUrl(rootUrl) + .setServicePath(servicePath); + } + + /** + * Returns a Dataflow client which does not automatically retry failed + * requests. + */ + public static Dataflow.Builder + newRawDataflowClient(DataflowPipelineOptions options) { + return newDataflowClient(options) + .setHttpRequestInitializer(options.getGcpCredential()); + } + + /** + * Returns a Cloud Storage client builder. + *

+ * Note: this client's endpoint is not modified by the + * {@link DataflowPipelineDebugOptions#getApiRootUrl()} option. + */ + public static Storage.Builder + newStorageClient(GcsOptions options) { + return new Storage.Builder(getTransport(), getJsonFactory(), + new RetryHttpRequestInitializer(options.getGcpCredential())) + .setApplicationName(options.getAppName()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UserCodeException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UserCodeException.java new file mode 100644 index 000000000000..a0bfed1626f9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UserCodeException.java @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Objects; + +/** + * An exception that was thrown in user-code. Sets the stack trace + * from the first time execution enters user code down through the + * rest of the user's stack frames until the exception is + * reached. + */ +public class UserCodeException extends RuntimeException { + private static final Logger LOG = LoggerFactory.getLogger(UserCodeException.class); + + public UserCodeException(Throwable t) { + super(t); + + StackTraceElement[] currentFrames = + Thread.currentThread().getStackTrace(); + + // We're interested in getting the third stack frame here, since + // the exception stack trace includes the getStackTrace frame from + // Thread and the frame from where the UserCodeException is + // actually thrown. If there aren't more than two frames, + // something is odd about where the exception was thrown, so leave + // the stack trace alone and allow it to propagate. + // + // For example, if an exception in user code has a stack trace like this: + // + // java.lang.NullPointerException + // at com.google.cloud.dataflow.sdk.examples. + // SimpleWordCount$ExtractWordsFn.dieHere(SimpleWordCount.java:23) + // at com.google.cloud.dataflow.sdk.examples. + // SimpleWordCount$ExtractWordsFn. + // processElement(SimpleWordCount.java:27) + // at com.google.cloud.dataflow.sdk. + // DoFnRunner.processElement(DoFnRunner.java:95) <-- caught here + // at com.google.cloud.dataflow.sdk. + // worker.NormalParDoFn.processElement(NormalParDoFn.java:119) + // at com.google.cloud.dataflow.sdk. + // worker.executor.ParDoOperation.process(ParDoOperation.java:65) + // at com.google.cloud.dataflow.sdk. + // worker.executor.ReadOperation.start(ReadOperation.java:65) + // at com.google.cloud.dataflow.sdk. + // worker.executor.MapTaskExecutor.execute(MapTaskExecutor.java:79) + // at com.google.cloud.dataflow.sdk. + // worker.DataflowWorkerHarness.main(DataflowWorkerHarness.java:95) + // + // It would be truncated to: + // + // java.lang.NullPointerException + // at com.google.cloud.dataflow.sdk.examples. + // SimpleWordCount$ExtractWordsFn.dieHere(SimpleWordCount.java:23) + // at com.google.cloud.dataflow.sdk.examples. + // SimpleWordCount$ExtractWordsFn. + // processElement(SimpleWordCount.java:27) + // + // However, we need to get the third stack frame from the + // getStackTrace, since after catching the error in DoFnRunner, + // the trace is two frames deeper by the time we get it: + // + // [0] java.lang.Thread.getStackTrace(Thread.java:1568) + // [1] com.google.cloud.dataflow.sdk. + // UserCodeException.(UserCodeException.java:16) + // [2] com.google.cloud.dataflow.sdk. + // DoFnRunner.processElement(DoFnRunner.java:95) <-- common frame + // + // We then proceed to truncate the original exception at the + // common frame, setting the UserCodeException's cause to the + // truncated stack trace. + + // Check to make sure the stack is > 2 deep. + if (currentFrames.length <= 2) { + LOG.error("Expecting stack trace to be > 2 frames long."); + return; + } + + // Perform some checks to make sure javac doesn't change from below us. + if (!Objects.equals(currentFrames[1].getClassName(), getClass().getName())) { + LOG.error("Expected second frame coming from Thread.currentThread.getStackTrace() " + + "to be {}, was: {}", getClass().getName(), currentFrames[1].getClassName()); + return; + } + if (Objects.equals(currentFrames[2].getClassName(), currentFrames[1].getClassName())) { + LOG.error("Javac's Thread.CurrentThread.getStackTrace() changed unexpectedly."); + return; + } + + // Now that all checks have passed, select the common frame. + StackTraceElement callingFrame = currentFrames[2]; + // Truncate the user-level stack trace below where the + // UserCodeException was thrown. + truncateStackTrace(callingFrame, t); + } + + /** + * Truncates this Throwable's stack frame at the given frame, + * removing all frames below. + */ + private void truncateStackTrace( + StackTraceElement currentFrame, Throwable t) { + int index = 0; + StackTraceElement[] stackTrace = t.getStackTrace(); + for (StackTraceElement element : stackTrace) { + if (Objects.equals(element.getClassName(), currentFrame.getClassName()) && + Objects.equals(element.getMethodName(), currentFrame.getMethodName())) { + t.setStackTrace(Arrays.copyOfRange(stackTrace, 0, index)); + break; + } + index++; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Values.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Values.java new file mode 100644 index 000000000000..f5ce4540d931 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Values.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A collection of static methods for manipulating value representations + * transfered via the Dataflow API. + */ +public final class Values { + private Values() {} // Non-instantiable + + public static Boolean asBoolean(Object value) throws ClassCastException { + @Nullable Boolean knownResult = checkKnownValue(CloudKnownType.BOOLEAN, value, Boolean.class); + if (knownResult != null) { + return knownResult; + } + return Boolean.class.cast(value); + } + + public static Double asDouble(Object value) throws ClassCastException { + @Nullable Double knownResult = checkKnownValue(CloudKnownType.FLOAT, value, Double.class); + if (knownResult != null) { + return knownResult; + } + if (value instanceof Double) { + return (Double) value; + } + return ((Float) value).doubleValue(); + } + + public static Long asLong(Object value) throws ClassCastException { + @Nullable Long knownResult = checkKnownValue(CloudKnownType.INTEGER, value, Long.class); + if (knownResult != null) { + return knownResult; + } + if (value instanceof Long) { + return (Long) value; + } + return ((Integer) value).longValue(); + } + + public static String asString(Object value) throws ClassCastException { + @Nullable String knownResult = checkKnownValue(CloudKnownType.TEXT, value, String.class); + if (knownResult != null) { + return knownResult; + } + return String.class.cast(value); + } + + @Nullable + private static T checkKnownValue(CloudKnownType type, Object value, Class clazz) { + if (!(value instanceof Map)) { + return null; + } + Map map = (Map) value; + @Nullable String typeName = (String) map.get(PropertyNames.OBJECT_TYPE_NAME); + if (typeName == null) { + return null; + } + @Nullable CloudKnownType knownType = CloudKnownType.forUri(typeName); + if (knownType == null || knownType != type) { + return null; + } + @Nullable Object scalar = map.get(PropertyNames.SCALAR_FIELD_NAME); + if (scalar == null) { + return null; + } + return knownType.parse(scalar, clazz); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/VarInt.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/VarInt.java new file mode 100644 index 000000000000..a7399473d4b4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/VarInt.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * Variable-length encoding for integers. + * + * Handles, in a common encoding format, signed bytes, shorts, ints, and longs. + * Takes between 1 and 10 bytes. + * Less efficient than BigEndian{Int,Long} coder for negative or large numbers. + * All negative ints are encoded using 5 bytes, longs take 10 bytes. + */ +public class VarInt { + + private static long convertIntToLongNoSignExtend(int v) { + return ((long) v) & 0xFFFFFFFFL; + } + + /** + * Encodes the given value onto the stream. + */ + public static void encode(int v, OutputStream stream) throws IOException { + encode(convertIntToLongNoSignExtend(v), stream); + } + + /** + * Encodes the given value onto the stream. + */ + public static void encode(long v, OutputStream stream) throws IOException { + do { + // Encode next 7 bits + terminator bit + long bits = v & 0x7F; + v >>>= 7; + byte b = (byte) (bits | ((v != 0) ? 0x80 : 0)); + stream.write(b); + } while (v != 0); + } + + /** + * Decodes an integer value from the given stream. + */ + public static int decodeInt(InputStream stream) throws IOException { + long r = decodeLong(stream); + if (r < 0 || r >= 1L << 32) { + throw new IOException("varint overflow " + r); + } + return (int) r; + } + + /** + * Decodes a long value from the given stream. + */ + public static long decodeLong(InputStream stream) throws IOException { + long result = 0; + int shift = 0; + int b; + do { + // Get 7 bits from next byte + b = stream.read(); + if (b < 0) { + if (shift == 0) { + throw new EOFException(); + } else { + throw new IOException("varint not terminated"); + } + } + long bits = b & 0x7F; + if (shift >= 64 || (shift == 63 && bits > 1)) { + // Out of range + throw new IOException("varint too long"); + } + result |= bits << shift; + shift += 7; + } while ((b & 0x80) != 0); + return result; + } + + /** + * Returns the length of the encoding of the given value (in bytes). + */ + public static int getLength(int v) { + return getLength(convertIntToLongNoSignExtend(v)); + } + + /** + * Returns the length of the encoding of the given value (in bytes). + */ + public static int getLength(long v) { + int result = 0; + do { + result++; + v >>>= 7; + } while (v != 0); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowUtils.java new file mode 100644 index 000000000000..de0a8f24ba64 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowUtils.java @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Base64; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.CodedTupleTag; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** + * Utility functions related to serializing windows. + */ +class WindowUtils { + private static final String BUFFER_TAG_PREFIX = "buffer:"; + + /** + * Converts the given window to a base64-encoded String using the given coder. + */ + public static String windowToString(W window, Coder coder) throws IOException { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + coder.encode(window, stream, Coder.Context.OUTER); + byte[] rawWindow = stream.toByteArray(); + return Base64.encodeBase64String(rawWindow); + } + + /** + * Parses a window from a base64-encoded String using the given coder. + */ + public static W windowFromString(String serializedWindow, Coder coder) throws IOException { + return coder.decode( + new ByteArrayInputStream(Base64.decodeBase64(serializedWindow)), + Coder.Context.OUTER); + } + + /** + * Returns a tag for storing buffered data in per-key state. + */ + public static CodedTupleTag bufferTag( + W window, Coder windowCoder, Coder elemCoder) + throws IOException { + return CodedTupleTag.of( + BUFFER_TAG_PREFIX + windowToString(window, windowCoder), elemCoder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowedValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowedValue.java new file mode 100644 index 000000000000..de310b827114 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowedValue.java @@ -0,0 +1,368 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CollectionCoder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * An immutable triple of value, timestamp, and windows. + * + * @param the type of the value + */ +public class WindowedValue { + + private final V value; + private final Instant timestamp; + private final Collection windows; + + /** + * Returns a {@code WindowedValue} with the given value, timestamp, and windows. + */ + public static WindowedValue of( + V value, + Instant timestamp, + Collection windows) { + return new WindowedValue<>(value, timestamp, windows); + } + + /** + * Returns a {@code WindowedValue} with the given value, default timestamp, + * and {@code GlobalWindow}. + */ + public static WindowedValue valueInGlobalWindow(V value) { + return new WindowedValue<>(value, + new Instant(Long.MIN_VALUE), + Arrays.asList(GlobalWindow.Window.INSTANCE)); + } + + /** + * Returns a {@code WindowedValue} with the given value and default timestamp and empty windows. + */ + public static WindowedValue valueInEmptyWindows(V value) { + return new WindowedValue<>(value, + new Instant(Long.MIN_VALUE), + new ArrayList()); + } + + private WindowedValue(V value, + Instant timestamp, + Collection windows) { + this.value = value; + this.timestamp = timestamp; + this.windows = windows; + } + + /** + * Returns a new {@code WindowedValue} that is a copy of this one, but with a different value. + */ + public WindowedValue withValue(V value) { + return new WindowedValue<>(value, this.timestamp, this.windows); + } + + /** + * Returns the value of this {@code WindowedValue}. + */ + public V getValue() { + return value; + } + + /** + * Returns the timestamp of this {@code WindowedValue}. + */ + public Instant getTimestamp() { + return timestamp; + } + + /** + * Returns the windows of this {@code WindowedValue}. + */ + public Collection getWindows() { + return windows; + } + + /** + * Returns the {@code Coder} to use for a {@code WindowedValue}, + * using the given valueCoder and windowCoder. + */ + public static WindowedValueCoder getFullCoder( + Coder valueCoder, + Coder windowCoder) { + return FullWindowedValueCoder.of(valueCoder, windowCoder); + } + + /** + * Returns the {@code ValueOnlyCoder} from the given valueCoder. + */ + public static WindowedValueCoder getValueOnlyCoder(Coder valueCoder) { + return ValueOnlyWindowedValueCoder.of(valueCoder); + } + + @Override + public boolean equals(Object o) { + if (o instanceof WindowedValue) { + WindowedValue that = (WindowedValue) o; + if (that.timestamp.isEqual(timestamp) && that.windows.size() == windows.size()) { + for (Iterator thatIterator = that.windows.iterator(), thisIterator = windows.iterator(); + thatIterator.hasNext() && thisIterator.hasNext(); + /* do nothng */) { + if (!thatIterator.next().equals(thisIterator.next())) { + return false; + } + } + return true; + } + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(value, timestamp, Arrays.hashCode(windows.toArray())); + } + + @Override + public String toString() { + return "[WindowedValue: " + value + ", timestamp: " + timestamp.getMillis() + + ", windows: " + windows + "]"; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Abstract class for {@code WindowedValue} coder. + */ + public abstract static class WindowedValueCoder + extends StandardCoder> { + final Coder valueCoder; + + WindowedValueCoder(Coder valueCoder) { + this.valueCoder = checkNotNull(valueCoder); + } + + /** + * Returns the value coder. + */ + public Coder getValueCoder() { + return valueCoder; + } + + /** + * Returns a new {@code WindowedValueCoder} that is a copy of this one, + * but with a different value coder. + */ + public abstract WindowedValueCoder withValueCoder(Coder valueCoder); + } + + /** + * Coder for {@code WindowedValue}. + */ + public static class FullWindowedValueCoder extends WindowedValueCoder { + private final Coder windowCoder; + // Precompute and cache the coder for a list of windows. + private final Coder> windowsCoder; + + public static FullWindowedValueCoder of( + Coder valueCoder, + Coder windowCoder) { + return new FullWindowedValueCoder<>(valueCoder, windowCoder); + } + + @JsonCreator + public static FullWindowedValueCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + return of(components.get(0), + (Coder) components.get(1)); + } + + @SuppressWarnings("unchecked") + FullWindowedValueCoder(Coder valueCoder, + Coder windowCoder) { + super(valueCoder); + this.windowCoder = checkNotNull(windowCoder); + // It's not possible to statically type-check correct use of the + // windowCoder (we have to ensure externally that we only get + // windows of the class handled by windowCoder), so type + // windowsCoder in a way that makes encode() and decode() work + // right, and cast the window type away here. + this.windowsCoder = (Coder) CollectionCoder.of(this.windowCoder); + } + + public Coder getWindowCoder() { + return windowCoder; + } + + public Coder> getWindowsCoder() { + return windowsCoder; + } + + @Override + public WindowedValueCoder withValueCoder(Coder valueCoder) { + return new FullWindowedValueCoder<>(valueCoder, windowCoder); + } + + @Override + public void encode(WindowedValue windowedElem, + OutputStream outStream, + Context context) + throws CoderException, IOException { + Context nestedContext = context.nested(); + valueCoder.encode(windowedElem.getValue(), outStream, nestedContext); + InstantCoder.of().encode( + windowedElem.getTimestamp(), outStream, nestedContext); + windowsCoder.encode(windowedElem.getWindows(), outStream, nestedContext); + } + + @Override + public WindowedValue decode(InputStream inStream, Context context) + throws CoderException, IOException { + Context nestedContext = context.nested(); + T value = valueCoder.decode(inStream, nestedContext); + Instant timestamp = InstantCoder.of().decode(inStream, nestedContext); + Collection windows = + windowsCoder.decode(inStream, nestedContext); + return WindowedValue.of(value, timestamp, windows); + } + + @Override + public boolean isDeterministic() { + return valueCoder.isDeterministic() && windowCoder.isDeterministic(); + } + + @Override + public void registerByteSizeObserver(WindowedValue value, + ElementByteSizeObserver observer, + Context context) throws Exception { + valueCoder.registerByteSizeObserver(value.getValue(), observer, context); + InstantCoder.of().registerByteSizeObserver(value.getTimestamp(), observer, context); + windowsCoder.registerByteSizeObserver(value.getWindows(), observer, context); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_WRAPPER, true); + return result; + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return Arrays.>asList(valueCoder, windowCoder); + } + } + + /** + * Coder for {@code WindowedValue}. + * + *

A {@code ValueOnlyWindowedValueCoder} only encodes and decodes the value. It drops + * timestamp and windows for encoding, and uses defaults timestamp, and windows for decoding. + */ + public static class ValueOnlyWindowedValueCoder extends WindowedValueCoder { + + public static ValueOnlyWindowedValueCoder of( + Coder valueCoder) { + return new ValueOnlyWindowedValueCoder<>(valueCoder); + } + + @JsonCreator + public static ValueOnlyWindowedValueCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + checkArgument(components.size() == 1, "Expecting 1 component, got " + components.size()); + return of(components.get(0)); + } + + ValueOnlyWindowedValueCoder(Coder valueCoder) { + super(valueCoder); + } + + @Override + public WindowedValueCoder withValueCoder(Coder valueCoder) { + return new ValueOnlyWindowedValueCoder<>(valueCoder); + } + + @Override + public void encode(WindowedValue windowedElem, OutputStream outStream, Context context) + throws CoderException, IOException { + valueCoder.encode(windowedElem.getValue(), outStream, context); + } + + @Override + public WindowedValue decode(InputStream inStream, Context context) + throws CoderException, IOException { + T value = valueCoder.decode(inStream, context); + return WindowedValue.valueInGlobalWindow(value); + } + + @Override + public boolean isDeterministic() { + return valueCoder.isDeterministic(); + } + + @Override + public void registerByteSizeObserver( + WindowedValue value, ElementByteSizeObserver observer, Context context) + throws Exception { + valueCoder.registerByteSizeObserver(value.getValue(), observer, context); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_WRAPPER, true); + return result; + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(valueCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Counter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Counter.java new file mode 100644 index 000000000000..8b5f636ac5da --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Counter.java @@ -0,0 +1,730 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.AND; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.OR; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SET; + +import com.google.common.reflect.TypeToken; + +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; +import java.util.logging.Logger; + +/** + * A Counter enables the aggregation of a stream of values over time. The + * cumulative aggregate value is updated as new values are added, or it can be + * reset to a new value. Multiple kinds of aggregation are supported depending + * on the type of the counter. + * + *

Counters compare using value equality of their name, kind, and + * cumulative value. Equal counters should have equal toString()s. + * + * @param the type of values aggregated by this counter + */ +public abstract class Counter { + private static final Logger LOG = Logger.getLogger(Counter.class.getName()); + + /** + * Possible kinds of counter aggregation. + */ + public static enum AggregationKind { + + /** + * Computes the sum of all added values. + * Applicable to {@link Integer}, {@link Long}, and {@link Double} values. + */ + SUM, + + /** + * Computes the maximum value of all added values. + * Applicable to {@link Integer}, {@link Long}, and {@link Double} values. + */ + MAX, + + /** + * Computes the minimum value of all added values. + * Applicable to {@link Integer}, {@link Long}, and {@link Double} values. + */ + MIN, + + /** + * Computes the arithmetic mean of all added values. Applicable to + * {@link Integer}, {@link Long}, and {@link Double} values. + */ + MEAN, + + /** + * Computes the set of all added values. Applicable to {@link Integer}, + * {@link Long}, {@link Double}, and {@link String} values. + */ + SET, + + /** + * Computes boolean AND over all added values. + * Applicable only to {@link Boolean} values. + */ + AND, + + /** + * Computes boolean OR over all added values. Applicable only to + * {@link Boolean} values. + */ + OR + // TODO: consider adding VECTOR_SUM, HISTOGRAM, KV_SET, PRODUCT, TOP. + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Integer}, values + * according to the desired aggregation kind. The supported aggregation kinds + * are {@link AggregationKind#SUM}, {@link AggregationKind#MIN}, + * {@link AggregationKind#MAX}, {@link AggregationKind#MEAN}, and + * {@link AggregationKind#SET}. This is a convenience wrapper over a + * {@link Counter} implementation that aggregates {@link Long} values. This is + * useful when the application handles (boxed) {@link Integer} values which + * are not readily convertible to the (boxed) {@link Long} values otherwise + * expected by the {@link Counter} implementation aggregating {@link Long} + * values. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter ints(String name, AggregationKind kind) { + return new IntegerCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Long} values + * according to the desired aggregation kind. The supported aggregation kinds + * are {@link AggregationKind#SUM}, {@link AggregationKind#MIN}, + * {@link AggregationKind#MAX}, {@link AggregationKind#MEAN}, and + * {@link AggregationKind#SET}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter longs(String name, AggregationKind kind) { + return new LongCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Double} values + * according to the desired aggregation kind. The supported aggregation kinds + * are {@link AggregationKind#SUM}, {@link AggregationKind#MIN}, + * {@link AggregationKind#MAX}, {@link AggregationKind#MEAN}, and + * {@link AggregationKind#SET}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter doubles(String name, AggregationKind kind) { + return new DoubleCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Boolean} values + * according to the desired aggregation kind. The only supported aggregation + * kinds are {@link AggregationKind#AND} and {@link AggregationKind#OR}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter booleans(String name, AggregationKind kind) { + return new BooleanCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link String} values + * according to the desired aggregation kind. The only supported aggregation + * kind is {@link AggregationKind#SET}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter strings(String name, AggregationKind kind) { + return new StringCounter(name, kind); + } + + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Adds a new value to the aggregation stream. Returns this (to allow method + * chaining). + */ + public abstract Counter addValue(T value); + + /** + * Resets the aggregation stream to this new value. Returns this (to allow + * method chaining). + */ + public Counter resetToValue(T value) { + return resetToValue(-1, value); + } + + /** + * Resets the aggregation stream to this new value. Returns this (to allow + * method chaining). The value of elementCount must be -1 for non-MEAN + * aggregations. The value of elementCount must be non-negative for MEAN + * aggregation. + */ + public synchronized Counter resetToValue(long elementCount, T value) { + aggregate = value; + deltaAggregate = value; + + if (kind.equals(MEAN)) { + if (elementCount < 0) { + throw new AssertionError( + "elementCount must be non-negative for MEAN aggregation"); + } + count = elementCount; + deltaCount = elementCount; + } else { + if (elementCount != -1) { + throw new AssertionError( + "elementCount must be -1 for non-MEAN aggregations"); + } + count = 0; + deltaCount = 0; + } + + if (kind.equals(SET)) { + set.clear(); + set.add(value); + deltaSet = new HashSet<>(); + deltaSet.add(value); + } + return this; + } + + /** Resets the counter's delta value to have no values accumulated. */ + public abstract void resetDelta(); + + /** + * Returns the counter's name. + */ + public String getName() { + return name; + } + + /** + * Returns the counter's aggregation kind. + */ + public AggregationKind getKind() { + return kind; + } + + /** + * Returns the counter's type. + */ + public Class getType() { + return new TypeToken(getClass()) {}.getRawType(); + } + + /** + * Returns the aggregated value, or the sum for MEAN aggregation, either + * total or, if delta, since the last update extraction or resetDelta, + * if not a SET aggregation. + */ + public T getAggregate(boolean delta) { + return delta ? deltaAggregate : aggregate; + } + + /** + * Returns the number of aggregated values, either total or, if + * delta, since the last update extraction or resetDelta, if a MEAN + * aggregation. + */ + public long getCount(boolean delta) { + return delta ? deltaCount : count; + } + + /** + * Returns the set of all aggregated values, either total or, if + * delta, since the last update extraction or resetDelta, if a SET + * aggregation. + */ + public Set getSet(boolean delta) { + return delta ? deltaSet : set; + } + + /** + * Returns a string representation of the Counter. Useful for debugging logs. + * Example return value: "ElementCount:SUM(15)". + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getName()); + sb.append(":"); + sb.append(getKind()); + sb.append("("); + switch (kind) { + case SUM: + case MAX: + case MIN: + case AND: + case OR: + sb.append(aggregate); + break; + case MEAN: + sb.append(aggregate); + sb.append("/"); + sb.append(count); + break; + case SET: + sb.append(set); + break; + default: + throw illegalArgumentException(); + } + sb.append(")"); + + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof Counter) { + Counter that = (Counter) o; + return this.name.equals(that.name) + && this.kind == that.kind + && this.getClass().equals(that.getClass()) + && this.count == that.count + && Objects.equals(this.aggregate, that.aggregate) + && Objects.equals(this.set, that.set); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), name, kind, aggregate, count, set); + } + + /** + * Returns whether this Counter is compatible with that Counter. If + * so, they can be merged into a single Counter. + */ + public boolean isCompatibleWith(Counter that) { + return this.name.equals(that.name) + && this.kind == that.kind + && this.getClass().equals(that.getClass()); + } + + + ////////////////////////////////////////////////////////////////////////////// + + /** The name of this counter. */ + protected final String name; + + /** The kind of aggregation function to apply to this counter. */ + protected final AggregationKind kind; + + /** The total cumulative aggregation value. Holds sum for MEAN aggregation. */ + protected T aggregate; + + /** The cumulative aggregation value since the last update extraction. */ + protected T deltaAggregate; + + /** The total number of aggregated values. Useful for MEAN aggregation. */ + protected long count; + + /** The number of aggregated values since the last update extraction. */ + protected long deltaCount; + + /** Holds the set of all aggregated values. Used only for SET aggregation. */ + protected Set set; + + /** Holds the set of aggregated values since the last update extraction. */ + protected Set deltaSet; + + protected Counter(String name, AggregationKind kind) { + this.name = name; + this.kind = kind; + this.count = 0; + this.deltaCount = 0; + if (kind.equals(SET)) { + set = new HashSet<>(); + deltaSet = new HashSet<>(); + } + } + + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Implements a {@link Counter} for {@link Long} values. + */ + private static class LongCounter extends Counter { + + /** Initializes a new {@link Counter} for {@link Long} values. */ + private LongCounter(String name, AggregationKind kind) { + super(name, kind); + switch (kind) { + case SUM: + case MEAN: + aggregate = deltaAggregate = 0L; + break; + case MAX: + aggregate = deltaAggregate = Long.MIN_VALUE; + break; + case MIN: + aggregate = deltaAggregate = Long.MAX_VALUE; + break; + case SET: + break; + default: + throw illegalArgumentException(); + } + } + + @Override + public synchronized LongCounter addValue(Long value) { + switch (kind) { + case SUM: + aggregate += value; + deltaAggregate += value; + break; + case MEAN: + aggregate += value; + deltaAggregate += value; + count++; + deltaCount++; + break; + case MAX: + aggregate = Math.max(aggregate, value); + deltaAggregate = Math.max(deltaAggregate, value); + break; + case MIN: + aggregate = Math.min(aggregate, value); + deltaAggregate = Math.min(deltaAggregate, value); + break; + case SET: + set.add(value); + deltaSet.add(value); + break; + default: + throw illegalArgumentException(); + } + return this; + } + + @Override + public synchronized void resetDelta() { + switch (kind) { + case SUM: + deltaAggregate = 0L; + break; + case MEAN: + deltaAggregate = 0L; + deltaCount = 0; + break; + case MAX: + deltaAggregate = Long.MIN_VALUE; + break; + case MIN: + deltaAggregate = Long.MAX_VALUE; + break; + case SET: + deltaSet = new HashSet<>(); + break; + default: + throw illegalArgumentException(); + } + } + } + + /** + * Implements a {@link Counter} for {@link Double} values. + */ + private static class DoubleCounter extends Counter { + + /** Initializes a new {@link Counter} for {@link Double} values. */ + private DoubleCounter(String name, AggregationKind kind) { + super(name, kind); + switch (kind) { + case SUM: + case MEAN: + aggregate = deltaAggregate = 0.0; + break; + case MAX: + aggregate = deltaAggregate = Double.MIN_VALUE; + break; + case MIN: + aggregate = deltaAggregate = Double.MAX_VALUE; + break; + case SET: + break; + default: + throw illegalArgumentException(); + } + } + + @Override + public synchronized DoubleCounter addValue(Double value) { + switch (kind) { + case SUM: + aggregate += value; + deltaAggregate += value; + break; + case MEAN: + aggregate += value; + deltaAggregate += value; + count++; + deltaCount++; + break; + case MAX: + aggregate = Math.max(aggregate, value); + deltaAggregate = Math.max(deltaAggregate, value); + break; + case MIN: + aggregate = Math.min(aggregate, value); + deltaAggregate = Math.min(deltaAggregate, value); + break; + case SET: + set.add(value); + deltaSet.add(value); + break; + default: + throw illegalArgumentException(); + } + return this; + } + + @Override + public synchronized void resetDelta() { + switch (kind) { + case SUM: + deltaAggregate = 0.0; + break; + case MEAN: + deltaAggregate = 0.0; + deltaCount = 0; + break; + case MAX: + deltaAggregate = Double.MIN_VALUE; + break; + case MIN: + deltaAggregate = Double.MAX_VALUE; + break; + case SET: + deltaSet = new HashSet<>(); + break; + default: + throw illegalArgumentException(); + } + } + } + + /** + * Implements a {@link Counter} for {@link Boolean} values. + */ + private static class BooleanCounter extends Counter { + + /** Initializes a new {@link Counter} for {@link Boolean} values. */ + private BooleanCounter(String name, AggregationKind kind) { + super(name, kind); + if (kind.equals(AND)) { + aggregate = deltaAggregate = true; + } else if (kind.equals(OR)) { + aggregate = deltaAggregate = false; + } else { + throw illegalArgumentException(); + } + } + + @Override + public synchronized BooleanCounter addValue(Boolean value) { + if (kind.equals(AND)) { + aggregate &= value; + deltaAggregate &= value; + } else { // kind.equals(OR)) + aggregate |= value; + deltaAggregate |= value; + } + return this; + } + + @Override + public synchronized void resetDelta() { + switch (kind) { + case AND: + deltaAggregate = true; + break; + case OR: + deltaAggregate = false; + break; + default: + throw illegalArgumentException(); + } + } + } + + /** + * Implements a {@link Counter} for {@link String} values. + */ + private static class StringCounter extends Counter { + + /** Initializes a new {@link Counter} for {@link String} values. */ + private StringCounter(String name, AggregationKind kind) { + super(name, kind); + if (!kind.equals(SET)) { + throw illegalArgumentException(); + } + } + + @Override + public synchronized StringCounter addValue(String value) { + set.add(value); + deltaSet.add(value); + return this; + } + + @Override + public synchronized void resetDelta() { + switch (kind) { + case SET: + deltaSet = new HashSet<>(); + break; + default: + throw illegalArgumentException(); + } + } + } + + /** + * Implements a {@link Counter} for {@link Integer} values. + */ + private static class IntegerCounter extends Counter { + + /** Initializes a new {@link Counter} for {@link Integer} values. */ + private IntegerCounter(String name, AggregationKind kind) { + super(name, kind); + switch (kind) { + case SUM: + case MEAN: + aggregate = deltaAggregate = 0; + break; + case MAX: + aggregate = deltaAggregate = Integer.MIN_VALUE; + break; + case MIN: + aggregate = deltaAggregate = Integer.MAX_VALUE; + break; + case SET: + break; + default: + throw illegalArgumentException(); + } + } + + @Override + public synchronized IntegerCounter addValue(Integer value) { + switch (kind) { + case SUM: + aggregate += value; + deltaAggregate += value; + break; + case MEAN: + aggregate += value; + deltaAggregate += value; + count++; + deltaCount++; + break; + case MAX: + aggregate = Math.max(aggregate, value); + deltaAggregate = Math.max(deltaAggregate, value); + break; + case MIN: + aggregate = Math.min(aggregate, value); + deltaAggregate = Math.min(deltaAggregate, value); + break; + case SET: + set.add(value); + deltaSet.add(value); + break; + default: + throw illegalArgumentException(); + } + return this; + } + + @Override + public synchronized void resetDelta() { + switch (kind) { + case SUM: + deltaAggregate = 0; + break; + case MEAN: + deltaAggregate = 0; + deltaCount = 0; + break; + case MAX: + deltaAggregate = Integer.MIN_VALUE; + break; + case MIN: + deltaAggregate = Integer.MAX_VALUE; + break; + case SET: + deltaSet = new HashSet<>(); + break; + default: + throw illegalArgumentException(); + } + } + } + + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Constructs an {@link IllegalArgumentException} explaining that this + * {@link Counter}'s aggregation kind is not supported by its value type. + */ + protected IllegalArgumentException illegalArgumentException() { + return new IllegalArgumentException("Cannot compute " + kind + + " aggregation over " + getType().getSimpleName() + " values."); + } + + + ////////////////////////////////////////////////////////////////////////////// + + // For testing. + synchronized T getTotalAggregate() { return aggregate; } + synchronized T getDeltaAggregate() { return deltaAggregate; } + synchronized long getTotalCount() { return count; } + synchronized long getDeltaCount() { return deltaCount; } + synchronized Set getTotalSet() { return set; } + synchronized Set getDeltaSet() { return deltaSet; } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterSet.java new file mode 100644 index 000000000000..a9e83f323791 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterSet.java @@ -0,0 +1,152 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.AbstractSet; +import java.util.HashMap; +import java.util.Iterator; + +/** + * A CounterSet maintains a set of {@link Counter}s. + * + *

Thread-safe. + */ +public class CounterSet extends AbstractSet> { + + /** Registered counters. */ + private final HashMap> counters = new HashMap<>(); + + private final AddCounterMutator addCounterMutator = new AddCounterMutator(); + + /** + * Constructs a CounterSet containing the given Counters. + */ + public CounterSet(Counter... counters) { + for (Counter counter : counters) { + addNewCounter(counter); + } + } + + /** + * Returns an object that supports adding additional counters into + * this CounterSet. + */ + public AddCounterMutator getAddCounterMutator() { + return addCounterMutator; + } + + /** + * Adds a new counter, throwing an exception if a counter of the + * same name already exists. + */ + public void addNewCounter(Counter counter) { + if (!addCounter(counter)) { + throw new IllegalArgumentException( + "Counter " + counter + " duplicates an existing counter in " + this); + } + } + + /** + * Adds the given Counter to this CounterSet. + * + *

If a counter with the same name already exists, it will be + * reused, as long as it is compatible. + * + * @return the Counter that was reused, or added + * @throws IllegalArgumentException if the a counter with the same + * name but an incompatible kind had already been added + */ + public synchronized Counter addOrReuseCounter(Counter counter) { + Counter oldCounter = counters.get(counter.getName()); + if (oldCounter == null) { + // A new counter. + counters.put(counter.getName(), counter); + return counter; + } + if (counter.isCompatibleWith(oldCounter)) { + // Return the counter to reuse. + @SuppressWarnings("unchecked") + Counter compatibleCounter = (Counter) oldCounter; + return compatibleCounter; + } + throw new IllegalArgumentException( + "Counter " + counter + " duplicates incompatible counter " + + oldCounter + " in " + this); + } + + /** + * Adds a counter. Returns {@code true} if the counter was added to the set + * and false if the given counter was {@code null} or it already existed in + * the set. + * + * @param counter to register + */ + public boolean addCounter(Counter counter) { + return add(counter); + } + + /** + * Returns the Counter with the given name in this CounterSet; + * returns null if no such Counter exists. + */ + public synchronized Counter getExistingCounter(String name) { + return counters.get(name); + } + + @Override + public synchronized Iterator> iterator() { + return counters.values().iterator(); + } + + @Override + public synchronized int size() { + return counters.size(); + } + + @Override + public synchronized boolean add(Counter e) { + if (null == e) { + return false; + } + if (counters.containsKey(e.getName())) { + return false; + } + counters.put(e.getName(), e); + return true; + } + + /** + * A nested class that supports adding additional counters into the + * enclosing CounterSet. This is useful as a mutator; hiding other + * public methods of the CounterSet. + */ + public class AddCounterMutator { + /** + * Adds the given Counter into the enclosing CounterSet. + * + *

If a counter with the same name already exists, it will be + * reused, as long as it has the same type. + * + * @return the Counter that was reused, or added + * @throws IllegalArgumentException if the a counter with the same + * name but an incompatible kind had already been added + */ + public Counter addCounter(Counter counter) { + return addOrReuseCounter(counter); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservable.java new file mode 100644 index 000000000000..447dadcb8ef7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservable.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common; + +/** + * An interface for things that allow observing the size in bytes of + * encoded values of type {@code T}. + * + * @param the type of the values being observed + */ +public interface ElementByteSizeObservable { + /** + * Returns whether {@link #registerByteSizeObserver} is cheap enough + * to call for every element, that is, if this + * {@code ElementByteSizeObservable} can calculate the byte size of + * the element to be coded in roughly constant time (or lazily). + */ + public boolean isRegisterByteSizeObserverCheap(T value); + + /** + * Notifies the {@code ElementByteSizeObserver} about the byte size + * of the encoded value using this {@code ElementByteSizeObservable}. + */ + public void registerByteSizeObserver(T value, + ElementByteSizeObserver observer) + throws Exception; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterable.java new file mode 100644 index 000000000000..f8f727090237 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterable.java @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.ArrayList; +import java.util.List; +import java.util.Observer; + +/** + * An abstract class used for iterables that notify observers about size in + * bytes of their elements, as they are being iterated over. + * + * @param the type of elements returned by this iterable + * @param type type of iterator returned by this iterable + */ +public abstract class ElementByteSizeObservableIterable< + V, VI extends ElementByteSizeObservableIterator> + implements Iterable { + private List observers = new ArrayList<>(); + + /** + * Derived classes override this method to return an iterator for this + * iterable. + */ + protected abstract VI createIterator(); + + /** + * Sets the observer, which will observe the iterator returned in + * the next call to iterator() method. Future calls to iterator() + * won't be observed, unless an observer is set again. + */ + public void addObserver(Observer observer) { + observers.add(observer); + } + + /** + * Returns a new iterator for this iterable. If an observer was set in + * a previous call to setObserver(), it will observe the iterator returned. + */ + @Override + public VI iterator() { + VI iterator = createIterator(); + for (Observer observer : observers) { + iterator.addObserver(observer); + } + observers.clear(); + return iterator; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterator.java new file mode 100644 index 000000000000..50c9add0edaa --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterator.java @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.Iterator; +import java.util.Observable; + +/** + * An abstract class used for iterators that notify observers about size in + * bytes of their elements, as they are being iterated over. The subclasses + * need to implement the standard Iterator interface and call method + * notifyValueReturned() for each element read and/or iterated over. + * + * @param value type + */ +public abstract class ElementByteSizeObservableIterator + extends Observable implements Iterator { + protected final void notifyValueReturned(long byteSize) { + setChanged(); + notifyObservers(byteSize); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObserver.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObserver.java new file mode 100644 index 000000000000..9cccb4365c6f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObserver.java @@ -0,0 +1,84 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.Observable; +import java.util.Observer; + +/** + * An observer that gets notified when additional bytes are read + * and/or used. It adds all bytes into a local counter. When the + * observer gets advanced via the next() call, it adds the total byte + * count to the specified counter, and prepares for the next element. + */ +public class ElementByteSizeObserver implements Observer { + private final Counter counter; + private boolean isLazy = false; + private long totalSize = 0; + + public ElementByteSizeObserver(Counter counter) { + this.counter = counter; + } + + /** + * Sets byte counting for the current element as lazy. That is, the + * observer will get notified of the element's byte count only as + * element's pieces are being processed or iterated over. + */ + public void setLazy() { + isLazy = true; + } + + /** + * Returns whether byte counting for the current element is lazy, that is, + * whether the observer gets notified of the element's byte count only as + * element's pieces are being processed or iterated over. + */ + public boolean getIsLazy() { + return isLazy; + } + + /** + * Updates the observer with a context specified, but without an instance of + * the Observable. + */ + public void update(Object obj) { + update(null, obj); + } + + @Override + public void update(Observable obs, Object obj) { + if (obj instanceof Long) { + totalSize += (Long) obj; + } else if (obj instanceof Integer) { + totalSize += (Integer) obj; + } else { + throw new AssertionError("unexpected parameter object"); + } + } + + /** + * Advances the observer to the next element. Adds the current total byte + * size to the counter, and prepares the observer for the next element. + */ + public void advance() { + counter.addValue(totalSize); + + totalSize = 0; + isLazy = false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ForwardingReiterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ForwardingReiterator.java new file mode 100644 index 000000000000..f3008232a107 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ForwardingReiterator.java @@ -0,0 +1,83 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.common.base.Preconditions.checkNotNull; + +/** + * A {@link Reiterator} which forwards to another {@code Reiterator}, useful for + * implementing {@code Reiterator} wrappers. + * + * @param the type of elements returned by this iterator + */ +public abstract class ForwardingReiterator + implements Reiterator, Cloneable { + private Reiterator base; + + /** + * Constructs a {@link ForwardingReiterator}. + * @param base supplies a base reiterator to forward requests to. This + * reiterator will be used directly; it will not be copied by the constructor. + */ + public ForwardingReiterator(Reiterator base) { + this.base = checkNotNull(base); + } + + @Override + protected ForwardingReiterator clone() { + ForwardingReiterator result; + try { + result = (ForwardingReiterator) super.clone(); + } catch (CloneNotSupportedException e) { + throw new AssertionError( + "Object.clone() for a ForwardingReiterator threw " + + "CloneNotSupportedException; this should not happen, " + + "since ForwardingReiterator implements Cloneable.", + e); + } + result.base = base.copy(); + return result; + } + + @Override + public boolean hasNext() { + return base.hasNext(); + } + + @Override + public T next() { + return base.next(); + } + + @Override + public void remove() { + base.remove(); + } + + /** + * {@inheritDoc} + * + *

This implementation uses {@link #clone} to construct a duplicate of the + * {@link Reiterator}. Derived classes must either implement + * {@link Cloneable} semantics, or must provide an alternative implementation + * of this method. + */ + @Override + public ForwardingReiterator copy() { + return clone(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Metric.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Metric.java new file mode 100644 index 000000000000..23a590743b21 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Metric.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +/** + * A metric (e.g., CPU usage) that can be reported by a worker. + * + * @param the type of the metric's value + */ +public abstract class Metric { + String name; + T value; + + public Metric(String name, T value) { + this.name = name; + this.value = value; + } + + public String getName() { return name; } + + public T getValue() { return value; } + + /** + * A double-valued Metric. + */ + public static class DoubleMetric extends Metric { + public DoubleMetric(String name, double value) { + super(name, value); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/PeekingReiterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/PeekingReiterator.java new file mode 100644 index 000000000000..d139380c65c1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/PeekingReiterator.java @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import java.util.NoSuchElementException; + +/** + * A {@link Reiterator} that supports one-element lookahead during iteration. + * + * @param the type of elements returned by this iterator + */ +public final class PeekingReiterator implements Reiterator { + private T nextElement; + private boolean nextElementComputed; + private final Reiterator iterator; + + public PeekingReiterator(Reiterator iterator) { + this.iterator = checkNotNull(iterator); + } + + PeekingReiterator(PeekingReiterator it) { + this.iterator = checkNotNull(it).iterator.copy(); + this.nextElement = it.nextElement; + this.nextElementComputed = it.nextElementComputed; + } + + @Override + public boolean hasNext() { + computeNext(); + return nextElementComputed; + } + + @Override + public T next() { + T result = peek(); + nextElementComputed = false; + return result; + } + + /** + * {@inheritDoc} + * + *

If {@link #peek} is called, {@code remove} is disallowed until + * {@link #next} has been subsequently called. + */ + @Override + public void remove() { + checkState(!nextElementComputed, + "After peek(), remove() is disallowed until next() is called"); + iterator.remove(); + } + + @Override + public PeekingReiterator copy() { + return new PeekingReiterator(this); + } + + /** + * Returns the element that would be returned by {@link #next}, without + * actually consuming the element. + * @throws NoSuchElementException if there is no next element + */ + public T peek() { + computeNext(); + if (!nextElementComputed) { + throw new NoSuchElementException(); + } + return nextElement; + } + + private void computeNext() { + if (nextElementComputed) { + return; + } + if (!iterator.hasNext()) { + return; + } + nextElement = iterator.next(); + nextElementComputed = true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterable.java new file mode 100644 index 000000000000..ebf30459e277 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterable.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +/** + * An {@link Iterable} that returns {@link Reiterator} iterators. + * + * @param the type of elements returned by the iterator + */ +public interface Reiterable extends Iterable { + @Override + public Reiterator iterator(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterator.java new file mode 100644 index 000000000000..7613a3a37bd3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterator.java @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.Iterator; + +/** + * An {@link Iterator} with the ability to copy its iteration state. + * + * @param the type of elements returned by this iterator + */ +public interface Reiterator extends Iterator { + /** + * Returns a copy of the current {@link Reiterator}. The copy's iteration + * state is logically independent of the current iterator; each may be + * advanced without affecting the other. + * + *

The returned {@code Reiterator} is not guaranteed to return + * referentially identical iteration results as the original + * {@link Reiterator}, although {@link Object#equals} will typically return + * true for the corresponding elements of each if the original source is + * logically immutable. + */ + public Reiterator copy(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/package-info.java new file mode 100644 index 000000000000..0dd2af486ba0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities shared by multiple PipelineRunner implementations. **/ +package com.google.cloud.dataflow.sdk.util.common; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/BatchingShuffleEntryReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/BatchingShuffleEntryReader.java new file mode 100644 index 000000000000..2a596c0d86f8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/BatchingShuffleEntryReader.java @@ -0,0 +1,148 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.util.common.Reiterator; + +import java.util.ListIterator; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * BatchingShuffleEntryReader provides a mechanism for reading entries from + * a shuffle dataset. + */ +@NotThreadSafe +public final class BatchingShuffleEntryReader implements ShuffleEntryReader { + private final ShuffleBatchReader batchReader; + + /** + * Constructs a {@link BatchingShuffleEntryReader} + * + * @param batchReader supplies the underlying + * {@link ShuffleBatchReader} to read batches of entries from + */ + public BatchingShuffleEntryReader( + ShuffleBatchReader batchReader) { + this.batchReader = checkNotNull(batchReader); + } + + @Override + public Reiterator read( + @Nullable ShufflePosition startPosition, + @Nullable ShufflePosition endPosition) { + return new ShuffleReadIterator(startPosition, endPosition); + } + + /** + * ShuffleReadIterator iterates over a (potentially huge) sequence of shuffle + * entries. + */ + private final class ShuffleReadIterator implements Reiterator { + // Shuffle service returns entries in pages. If the response contains a + // non-null nextStartPosition, we have to ask for more pages. The response + // with null nextStartPosition signifies the end of stream. + @Nullable private final ShufflePosition endPosition; + @Nullable private ShufflePosition nextStartPosition; + + /** The most recently read batch. */ + @Nullable ShuffleBatchReader.Batch currentBatch; + /** An iterator over the most recently read batch. */ + @Nullable private ListIterator entries; + + ShuffleReadIterator(@Nullable ShufflePosition startPosition, + @Nullable ShufflePosition endPosition) { + this.nextStartPosition = startPosition; + this.endPosition = endPosition; + } + + private ShuffleReadIterator(ShuffleReadIterator it) { + this.endPosition = it.endPosition; + this.nextStartPosition = it.nextStartPosition; + this.currentBatch = it.currentBatch; + // The idea here: if the iterator being copied was in the middle of a + // batch (the typical case), create a new iteration state at the same + // point in the same batch. + this.entries = (it.entries == null + ? null + : it.currentBatch.entries.listIterator(it.entries.nextIndex())); + } + + @Override + public boolean hasNext() { + fillEntriesIfNeeded(); + // TODO: Report API errors to the caller using checked + // exceptions. + return entries.hasNext(); + } + + @Override + public ShuffleEntry next() throws NoSuchElementException { + fillEntriesIfNeeded(); + ShuffleEntry entry = entries.next(); + return entry; + } + + @Override + public void remove() throws UnsupportedOperationException { + throw new UnsupportedOperationException(); + } + + @Override + public ShuffleReadIterator copy() { + return new ShuffleReadIterator(this); + } + + private void fillEntriesIfNeeded() { + if (entries != null && entries.hasNext()) { + // Has more records in the current page, or error. + return; + } + + if (entries != null && nextStartPosition == null) { + // End of stream. + checkState(!entries.hasNext()); + return; + } + + do { + fillEntries(); + } while (!entries.hasNext() && nextStartPosition != null); + } + + private void fillEntries() { + try { + ShuffleBatchReader.Batch batch = + batchReader.read(nextStartPosition, endPosition); + nextStartPosition = batch.nextStartPosition; + entries = batch.entries.listIterator(); + currentBatch = batch; + } catch (RuntimeException e) { + throw e; + } catch (Throwable t) { + throw new RuntimeException(t); + } + + checkState(entries != null); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/CachingShuffleBatchReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/CachingShuffleBatchReader.java new file mode 100644 index 000000000000..87abf21d4229 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/CachingShuffleBatchReader.java @@ -0,0 +1,228 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Objects; +import com.google.common.base.Throwables; + +import java.io.IOException; +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.SoftReference; +import java.util.HashMap; + +import javax.annotation.Nullable; + +/** A {@link ShuffleBatchReader} that caches batches as they're read. */ +public final class CachingShuffleBatchReader implements ShuffleBatchReader { + private final ShuffleBatchReader reader; + + // The cache itself is implemented as a HashMap of RangeReadReference values, + // keyed by the start and end positions describing the range of a particular + // request (represented by BatchRange). + // + // The first reader for a particular range builds an AsyncReadResult for the + // result, inserts it into the cache, drops the lock, and then completes the + // read; subsequent readers simply wait for the AsyncReadResult to complete. + // + // Note that overlapping ranges are considered distinct; cached entries for + // one range are not used for any other range, even if doing so would avoid a + // fetch. + // + // So this is not a particularly sophisticated algorithm: a smarter cache + // would be able to use subranges of previous requests to satisfy new + // requests. But in this particular case, we expect that the simple algorithm + // will work well. For a given shuffle source, the splits read by various + // iterators over that source starting from a particular position (which is + // how this class is used in practice) should turn out to be constant, if the + // result returned by the service for a particular [start, end) range are + // consistent. So we're not expecting to see overlapping ranges of entries + // within a cache. + // + // It's also been shown -- by implementing it -- that the more thorough + // algorithm is relatively complex, with numerous edge cases requiring very + // careful thought to get right. It's doable, but non-trivial and hard to + // understand and maintain; without a compelling justification, it's better to + // stick with the simpler implementation. + // + // @VisibleForTesting + final HashMap cache = new HashMap<>(); + + // The queue of references which have been collected by the garbage collector. + // This queue should only be used with references of class RangeReadReference. + private final ReferenceQueue refQueue = new ReferenceQueue<>(); + + /** + * Constructs a new {@link CachingShuffleBatchReader}. + * + * @param reader supplies the downstream {@link ShuffleBatchReader} + * this {@code CachingShuffleBatchReader} will use to issue reads + */ + public CachingShuffleBatchReader(ShuffleBatchReader reader) { + this.reader = checkNotNull(reader); + } + + @Override + public Batch read( + @Nullable ShufflePosition startPosition, + @Nullable ShufflePosition endPosition) throws IOException { + + @Nullable AsyncReadResult waitResult = null; + @Nullable AsyncReadResult runResult = null; + final BatchRange batchRange = new BatchRange(startPosition, endPosition); + + synchronized (cache) { + // Remove any GCd entries. + for (Reference ref = refQueue.poll(); + ref != null; + ref = refQueue.poll()) { + RangeReadReference rangeReadRef = (RangeReadReference) ref; + cache.remove(rangeReadRef.getBatchRange()); + } + + // Find the range reference; note that one might not be in the map, or it + // might contain a null if its target has been GCd. + @Nullable RangeReadReference rangeReadRef = cache.get(batchRange); + + // Get a strong reference to the existing AsyncReadResult for the range, if possible. + if (rangeReadRef != null) { + waitResult = rangeReadRef.get(); + } + + // Create a new AsyncReadResult if one is needed. + if (waitResult == null) { + runResult = new AsyncReadResult(); + waitResult = runResult; + rangeReadRef = null; // Replace the previous RangeReadReference. + } + + // Insert a new RangeReadReference into the map if we don't have a usable + // one (either we weren't able to find one in the map, or we did but it + // was already cleared by the GC). + if (rangeReadRef == null) { + cache.put(batchRange, + new RangeReadReference(batchRange, runResult, refQueue)); + } + } // Drop the cache lock. + + if (runResult != null) { + // This thread created the AsyncReadResult, and is responsible for + // actually performing the read. + try { + Batch result = reader.read(startPosition, endPosition); + runResult.setResult(result); + } catch (RuntimeException | IOException e) { + runResult.setException(e); + synchronized (cache) { + // No reason to continue to cache the fact that there was a problem. + // Note that since this thread holds a strong reference to the + // AsyncReadResult, it won't be GCd, so the soft reference held by the + // cache is guaranteed to still be present. + cache.remove(batchRange); + } + } + } + + return waitResult.getResult(); + } + + /** The key for the entries stored in the batch cache. */ + // @VisibleForTesting + static final class BatchRange { + @Nullable private final ShufflePosition startPosition; + @Nullable private final ShufflePosition endPosition; + + public BatchRange(@Nullable ShufflePosition startPosition, + @Nullable ShufflePosition endPosition) { + this.startPosition = startPosition; + this.endPosition = endPosition; + } + + @Override + public boolean equals(Object o) { + return o == this + || (o instanceof BatchRange + && Objects.equal(((BatchRange) o).startPosition, startPosition) + && Objects.equal(((BatchRange) o).endPosition, endPosition)); + } + + @Override + public int hashCode() { + return Objects.hashCode(startPosition, endPosition); + } + } + + /** Holds an asynchronously batch read result. */ + private static final class AsyncReadResult { + @Nullable private Batch batch = null; + @Nullable private Throwable thrown = null; + + public synchronized void setResult(Batch b) { + batch = b; + notifyAll(); + } + + public synchronized void setException(Throwable t) { + thrown = t; + notifyAll(); + } + + public synchronized Batch getResult() throws IOException { + while (batch == null && thrown == null) { + try { + wait(); + } catch (InterruptedException e) { + throw new RuntimeException("interrupted", e); + } + } + if (thrown != null) { + // N.B. setException can only be called with a RuntimeException or an + // IOException, so propagateIfPossible should always do the throw. + Throwables.propagateIfPossible(thrown, IOException.class); + throw new RuntimeException("unexpected", thrown); + } + return batch; + } + } + + /** + * Maintains a soft reference to an AsyncReadResult. + * + *

This class extends {@link SoftReference} so that when the garbage + * collector collects a batch and adds its reference to the cache's reference + * queue, that reference can be cast back to {@code RangeReadReference}, + * allowing us to identify the reference's position in the cache (and to + * therefore remove it). + */ + // @VisibleForTesting + static final class RangeReadReference extends SoftReference { + private final BatchRange range; + + public RangeReadReference( + BatchRange range, AsyncReadResult result, + ReferenceQueue refQueue) { + super(result, refQueue); + this.range = checkNotNull(range); + } + + public BatchRange getBatchRange() { + return range; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/CustomSourceFormat.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/CustomSourceFormat.java new file mode 100644 index 000000000000..4fc67d60f3c3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/CustomSourceFormat.java @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +/** + * An interface for sources which can perform operations on source specifications, such as + * splitting the source and computing its metadata. See {@code SourceOperationRequest} for details. + */ +public interface CustomSourceFormat { + /** + * Performs an operation on the specification of a source. + * See {@code SourceOperationRequest} for details. + */ + public SourceOperationResponse performSourceOperation(SourceOperationRequest operation) + throws Exception; + + /** + * A representation of an operation on the specification of a source, + * e.g. splitting a source into shards, getting the metadata of a source, + * etc. + * + *

The common worker framework does not interpret instances of + * this interface. But a tool-specific framework can make assumptions + * about the implementation, and so the concrete Source subclasses used + * by a tool-specific framework should match. + */ + public interface SourceOperationRequest { + } + + /** + * A representation of the result of a SourceOperationRequest. + * + *

See the comment on {@link SourceOperationRequest} for how instances of this + * interface are used by the rest of the framework. + */ + public interface SourceOperationResponse { + } + + /** + * A representation of a specification of a source. + * + *

See the comment on {@link SourceOperationRequest} for how instances of this + * interface are used by the rest of the framework. + */ + public interface SourceSpec { + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/FlattenOperation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/FlattenOperation.java new file mode 100644 index 000000000000..6325d1ac5cdb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/FlattenOperation.java @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * A flatten operation. + */ +public class FlattenOperation extends ReceivingOperation { + public FlattenOperation(String operationName, + OutputReceiver[] receivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(operationName, receivers, + counterPrefix, addCounterMutator, stateSampler); + } + + /** Invoked by tests. */ + public FlattenOperation(OutputReceiver outputReceiver, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + this("FlattenOperation", new OutputReceiver[]{ outputReceiver }, + counterPrefix, addCounterMutator, stateSampler); + } + + @Override + public void process(Object elem) throws Exception { + try (StateSampler.ScopedState process = + stateSampler.scopedState(processState)) { + checkStarted(); + Receiver receiver = receivers[0]; + if (receiver != null) { + receiver.process(elem); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/GroupingShuffleEntryIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/GroupingShuffleEntryIterator.java new file mode 100644 index 000000000000..19428201f039 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/GroupingShuffleEntryIterator.java @@ -0,0 +1,216 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservableIterable; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservableIterator; +import com.google.cloud.dataflow.sdk.util.common.PeekingReiterator; +import com.google.cloud.dataflow.sdk.util.common.Reiterable; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * An iterator through KeyGroupedShuffleEntries. + */ +public abstract class GroupingShuffleEntryIterator + implements Iterator { + /** The iterator through the underlying shuffle records. */ + private PeekingReiterator shuffleIterator; + + /** + * The key of the most recent KeyGroupedShuffleEntries returned by + * {@link #next}, if any. + * + *

If currentKeyBytes is non-null, then it's the key for the last entry + * returned by {@link #next}, and all incoming entries with that key should + * be skipped over by this iterator (since this iterator is iterating over + * keys, not the individual values associated with a given key). + * + *

If currentKeyBytes is null, and shuffleIterator.hasNext(), then the + * key of shuffleIterator.next() is the key of the next + * KeyGroupedShuffleEntries to return from {@link #next}. + */ + @Nullable private byte[] currentKeyBytes = null; + + /** + * Constructs a GroupingShuffleEntryIterator, given a Reiterator + * over ungrouped ShuffleEntries, assuming the ungrouped + * ShuffleEntries for a given key are consecutive. + */ + public GroupingShuffleEntryIterator( + Reiterator shuffleIterator) { + this.shuffleIterator = + new PeekingReiterator( + new ProgressTrackingReiterator<>( + shuffleIterator, + new ProgressTrackerGroup() { + @Override + protected void report(ShuffleEntry entry) { + notifyElementRead(entry.length()); + } + }.start())); + } + + /** Notifies observers about a new element read. */ + protected abstract void notifyElementRead(long byteSize); + + @Override + public boolean hasNext() { + advanceIteratorToNextKey(); + return shuffleIterator.hasNext(); + } + + @Override + public KeyGroupedShuffleEntries next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + ShuffleEntry entry = shuffleIterator.peek(); + currentKeyBytes = entry.getKey(); + return new KeyGroupedShuffleEntries( + entry.getPosition(), + currentKeyBytes, + new ValuesIterable(new ValuesIterator(currentKeyBytes))); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + private void advanceIteratorToNextKey() { + if (currentKeyBytes == null) { + return; + } + while (shuffleIterator.hasNext()) { + ShuffleEntry entry = shuffleIterator.peek(); + if (!Arrays.equals(entry.getKey(), currentKeyBytes)) { + break; + } + shuffleIterator.next(); + } + currentKeyBytes = null; + } + + private static class ValuesIterable + extends ElementByteSizeObservableIterable + implements Reiterable { + private final ValuesIterator base; + + public ValuesIterable(ValuesIterator base) { + this.base = checkNotNull(base); + } + + @Override + public ValuesIterator createIterator() { + return base.copy(); + } + } + + /** + * Provides the {@link Reiterator} used to iterate through the + * shuffle entries of a KeyGroupedShuffleEntries. + */ + private class ValuesIterator + extends ElementByteSizeObservableIterator + implements Reiterator { + // N.B. This class is *not* static; it maintains a reference to its + // enclosing KeyGroupedShuffleEntriesIterator instance so that it can update + // that instance's shuffleIterator as an optimization. + + private final byte[] valueKeyBytes; + private final PeekingReiterator valueShuffleIterator; + private final ProgressTracker tracker; + private boolean nextKnownValid = false; + + public ValuesIterator(byte[] valueKeyBytes) { + this.valueKeyBytes = checkNotNull(valueKeyBytes); + this.valueShuffleIterator = shuffleIterator.copy(); + // N.B. The ProgressTrackerGroup captures the reference to the original + // ValuesIterator for a given values iteration. Which happens to be + // exactly what we want, since this is also the ValuesIterator whose + // base Observable has the references to all of the Observers watching + // the iteration. Copied ValuesIterator instances do *not* have these + // Observers, but that's fine, since the derived ProgressTracker + // instances reference the ProgressTrackerGroup which references the + // original ValuesIterator which does have them. + this.tracker = new ProgressTrackerGroup() { + @Override + protected void report(ShuffleEntry entry) { + notifyValueReturned(entry.length()); + } + }.start(); + } + + private ValuesIterator(ValuesIterator it) { + this.valueKeyBytes = it.valueKeyBytes; + this.valueShuffleIterator = it.valueShuffleIterator.copy(); + this.tracker = it.tracker.copy(); + this.nextKnownValid = it.nextKnownValid; + } + + @Override + public boolean hasNext() { + if (nextKnownValid) { + return true; + } + if (!valueShuffleIterator.hasNext()) { + return false; + } + ShuffleEntry entry = valueShuffleIterator.peek(); + nextKnownValid = Arrays.equals(entry.getKey(), valueKeyBytes); + + // Opportunistically update the parent KeyGroupedShuffleEntriesIterator, + // potentially allowing it to skip a large number of key/value pairs + // with this key. + if (!nextKnownValid && valueKeyBytes == currentKeyBytes) { + shuffleIterator = valueShuffleIterator.copy(); + currentKeyBytes = null; + } + + return nextKnownValid; + } + + @Override + public ShuffleEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + ShuffleEntry entry = valueShuffleIterator.next(); + nextKnownValid = false; + tracker.saw(entry); + return entry; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public ValuesIterator copy() { + return new ValuesIterator(this); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/KeyGroupedShuffleEntries.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/KeyGroupedShuffleEntries.java new file mode 100644 index 000000000000..1b8b552b521e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/KeyGroupedShuffleEntries.java @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.Reiterable; + +/** + * A collection of ShuffleEntries, all with the same key. + */ +public class KeyGroupedShuffleEntries { + public final byte[] position; + public final byte[] key; + public final Reiterable values; + + public KeyGroupedShuffleEntries(byte[] position, byte[] key, + Reiterable values) { + this.position = position; + this.key = key; + this.values = values; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/MapTaskExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/MapTaskExecutor.java new file mode 100644 index 000000000000..45d5e8c6715e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/MapTaskExecutor.java @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.ListIterator; + +/** + * An executor for a map task, defined by a list of Operations. + */ +public class MapTaskExecutor extends WorkExecutor { + private static final Logger LOG = + LoggerFactory.getLogger(MapTaskExecutor.class); + + /** The operations in the map task, in execution order. */ + public final List operations; + + /** The StateSampler for tracking where time is being spent, or null. */ + protected final StateSampler stateSampler; + + /** + * Creates a new MapTaskExecutor. + * + * @param operations the operations of the map task, in order of execution + * @param counters a set of system counters associated with + * operations, which may get extended during execution + * @param stateSampler a state sampler for tracking where time is being spent + */ + public MapTaskExecutor(List operations, + CounterSet counters, + StateSampler stateSampler) { + super(counters); + this.operations = operations; + this.stateSampler = stateSampler; + } + + @Override + public void execute() throws Exception { + LOG.debug("executing map task"); + + // Start operations, in reverse-execution-order, so that a + // consumer is started before a producer might output to it. + // Starting a root operation such as a ReadOperation does the work + // of processing the input dataset. + LOG.debug("starting operations"); + ListIterator iterator = + operations.listIterator(operations.size()); + while (iterator.hasPrevious()) { + Operation op = iterator.previous(); + op.start(); + } + + // Finish operations, in forward-execution-order, so that a + // producer finishes outputting to its consumers before those + // consumers are themselves finished. + LOG.debug("finishing operations"); + for (Operation op : operations) { + op.finish(); + } + + LOG.debug("map task execution complete"); + + // TODO: support for success / failure ports? + } + + @Override + public Source.Progress getWorkerProgress() throws Exception { + return getReadOperation().getProgress(); + } + + @Override + public Source.Position proposeStopPosition( + Source.Progress proposedStopPosition) throws Exception { + return getReadOperation().proposeStopPosition(proposedStopPosition); + } + + ReadOperation getReadOperation() throws Exception { + if (operations == null || operations.isEmpty()) { + throw new IllegalStateException( + "Map task has no operation."); + } + + Operation readOperation = operations.get(0); + if (!(readOperation instanceof ReadOperation)) { + throw new IllegalStateException( + "First operation in the map task is not a ReadOperation."); + } + + return (ReadOperation) readOperation; + } + + @Override + public void close() throws Exception { + stateSampler.close(); + super.close(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Operation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Operation.java new file mode 100644 index 000000000000..bedc081cec99 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Operation.java @@ -0,0 +1,132 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * The abstract base class for Operations, which correspond to + * Instructions in the original MapTask InstructionGraph. + * + * Call start() to start the operation. + * + * A read operation's start() method actually reads the data, and in + * effect runs the pipeline. + * + * Call finish() to finish the operation. + * + * Since both start() and finish() may call process() on + * this operation's consumers, start an operation after + * starting its consumers, and finish an operation before + * finishing its consumers. + */ +public abstract class Operation { + /** + * The array of consuming receivers, one per operation output + * "port" (e.g., DoFn main or side output). A receiver might be + * null if that output isn't being consumed. + */ + public final OutputReceiver[] receivers; + + /** + * The possible initialization states of an Operation. + * For internal self-checking purposes. + */ + public enum InitializationState { + // start() hasn't yet been called. + UNSTARTED, + + // start() has been called, but finish() hasn't yet been called. + STARTED, + + // finish() has been called. + FINISHED + } + + /** The initialization state of this Operation. */ + public InitializationState initializationState = + InitializationState.UNSTARTED; + + protected final StateSampler stateSampler; + + protected final int startState; + protected final int processState; + protected final int finishState; + + public Operation(String operationName, + OutputReceiver[] receivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + this.receivers = receivers; + this.stateSampler = stateSampler; + startState = stateSampler.stateForName(operationName + "-start"); + processState = stateSampler.stateForName(operationName + "-process"); + finishState = stateSampler.stateForName(operationName + "-finish"); + } + + /** + * Checks that this oepration is not yet started, throwing an + * exception otherwise. + */ + void checkUnstarted() { + if (initializationState != InitializationState.UNSTARTED) { + throw new AssertionError( + "expecting this instruction to not yet be started"); + } + } + + /** + * Checks that this oepration has been started but not yet finished, + * throwing an exception otherwise. + */ + void checkStarted() { + if (initializationState != InitializationState.STARTED) { + throw new AssertionError( + "expecting this instruction to be started"); + } + } + + /** + * Checks that this oepration has been finished, throwing an + * exception otherwise. + */ + void checkFinished() { + if (initializationState != InitializationState.FINISHED) { + throw new AssertionError( + "expecting this instruction to be finished"); + } + } + + /** + * Starts this Operation's execution. Called after all successsor + * consuming operations have been started. + */ + public void start() throws Exception { + checkUnstarted(); + initializationState = InitializationState.STARTED; + } + + /** + * Finishes this Operation's execution. Called after all + * predecessor producing operations have been finished. + */ + public void finish() throws Exception { + checkStarted(); + initializationState = InitializationState.FINISHED; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/OutputReceiver.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/OutputReceiver.java new file mode 100644 index 000000000000..a13b74afbf8b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/OutputReceiver.java @@ -0,0 +1,207 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservable; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Receiver that forwards each input it receives to each of a list of + * output Receivers. Additionally, it tracks output counters, that is, size + * information for elements passing through. + */ +public class OutputReceiver implements Receiver { + private final String outputName; + // Might be null, e.g., undeclared outputs will not have an + // elementByteSizeObservable. + private final ElementByteSizeObservable elementByteSizeObservable; + private final Counter elementCount; + private Counter byteCount = null; + private Counter meanByteCount = null; + private ElementByteSizeObserver byteCountObserver = null; + private ElementByteSizeObserver meanByteCountObserver = null; + private final List outputs = new ArrayList<>(); + private final Random randomGenerator = new Random(); + private int samplingToken = 0; + private final int samplingTokenUpperBound = 1000000; // Lowest sampling probability: 0.001%. + private final int samplingCutoff = 10; + + public OutputReceiver(String outputName, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator) { + this(outputName, (ElementByteSizeObservable) null, + counterPrefix, addCounterMutator); + } + + public OutputReceiver(String outputName, + ElementByteSizeObservable elementByteSizeObservable, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator) { + this.outputName = outputName; + this.elementByteSizeObservable = elementByteSizeObservable; + + elementCount = addCounterMutator.addCounter( + Counter.longs(elementsCounterName(counterPrefix, outputName), SUM)); + + if (elementByteSizeObservable != null) { + String bytesCounterName = bytesCounterName(counterPrefix, outputName); + if (bytesCounterName != null) { + byteCount = addCounterMutator.addCounter( + Counter.longs(bytesCounterName, SUM)); + byteCountObserver = new ElementByteSizeObserver(byteCount); + } + String meanBytesCounterName = + meanBytesCounterName(counterPrefix, outputName); + if (meanBytesCounterName != null) { + meanByteCount = addCounterMutator.addCounter( + Counter.longs(meanBytesCounterName, MEAN)); + meanByteCountObserver = new ElementByteSizeObserver(meanByteCount); + } + } + } + + protected String elementsCounterName(String counterPrefix, + String outputName) { + return outputName + "-ElementCount"; + } + protected String bytesCounterName(String counterPrefix, + String outputName) { + return null; + } + protected String meanBytesCounterName(String counterPrefix, + String outputName) { + return outputName + "-MeanByteCount"; + } + + /** + * Adds a new receiver that this OutputReceiver forwards to. + */ + public void addOutput(Receiver receiver) { + outputs.add(receiver); + } + + @Override + public void process(Object elem) throws Exception { + // Increment element counter. + elementCount.addValue(1L); + + // Increment byte counter. + boolean advanceByteCountObserver = false; + boolean advanceMeanByteCountObserver = false; + if ((byteCountObserver != null || meanByteCountObserver != null) + && (sampleElement() + || elementByteSizeObservable.isRegisterByteSizeObserverCheap( + elem))) { + + if (byteCountObserver != null) { + elementByteSizeObservable.registerByteSizeObserver( + elem, byteCountObserver); + } + if (meanByteCountObserver != null) { + elementByteSizeObservable.registerByteSizeObserver( + elem, meanByteCountObserver); + } + + if (byteCountObserver != null) { + if (!byteCountObserver.getIsLazy()) { + byteCountObserver.advance(); + } else { + advanceByteCountObserver = true; + } + } + if (meanByteCountObserver != null) { + if (!meanByteCountObserver.getIsLazy()) { + meanByteCountObserver.advance(); + } else { + advanceMeanByteCountObserver = true; + } + } + } + + // Fan-out. + for (Receiver out : outputs) { + if (out != null) { + out.process(elem); + } + } + + // Advance lazy ElementByteSizeObservers, if any. + // Note that user's code is allowed to store the element of one + // DoFn.processElement() call and access it later on. We are still + // calling next() here, causing an update to byteCount. If user's + // code really accesses more element's pieces later on, their byte + // count would accrue against a future element. This is not ideal, + // but still approximately correct. + if (advanceByteCountObserver) { + byteCountObserver.advance(); + } + if (advanceMeanByteCountObserver) { + meanByteCountObserver.advance(); + } + } + + public String getName() { + return outputName; + } + + public Counter getElementCount() { + return elementCount; + } + + public Counter getByteCount() { + return byteCount; + } + + public Counter getMeanByteCount() { + return meanByteCount; + } + + protected boolean sampleElement() { + // Sampling probability decreases as the element count is increasing. + // We unconditionally sample the first samplingCutoff elements. For the + // next samplingCutoff elements, the sampling probability drops from 100% + // to 50%. The probability of sampling the Nth element is: + // min(1, samplingCutoff / N), with an additional lower bound of + // samplingCutoff / samplingTokenUpperBound. This algorithm may be refined + // later. + samplingToken = Math.min(samplingToken + 1, samplingTokenUpperBound); + return randomGenerator.nextInt(samplingToken) < samplingCutoff; + } + + /** Invoked by tests only. */ + public int getReceiverCount() { + return outputs.size(); + } + + /** Invoked by tests only. */ + public Receiver getOnlyReceiver() { + if (outputs.size() != 1) { + throw new AssertionError("only one receiver expected"); + } + + return outputs.get(0); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoFn.java new file mode 100644 index 000000000000..b922acc412d4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoFn.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +/** + * Abstract base class for ParDoFns, invocable by ParDoOperations. + */ +public abstract class ParDoFn { + public abstract void startBundle(Receiver... receivers) throws Exception; + + public abstract void processElement(Object elem) throws Exception; + + public abstract void finishBundle() throws Exception; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoOperation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoOperation.java new file mode 100644 index 000000000000..7a620983476f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoOperation.java @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * A ParDo mapping function. + */ +public class ParDoOperation extends ReceivingOperation { + public final ParDoFn fn; + + public ParDoOperation(String operationName, + ParDoFn fn, + OutputReceiver[] outputReceivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(operationName, outputReceivers, + counterPrefix, addCounterMutator, stateSampler); + this.fn = fn; + } + + @Override + public void start() throws Exception { + try (StateSampler.ScopedState start = + stateSampler.scopedState(startState)) { + super.start(); + fn.startBundle(receivers); + } + } + + @Override + public void process(Object elem) throws Exception { + try (StateSampler.ScopedState process = + stateSampler.scopedState(processState)) { + checkStarted(); + fn.processElement(elem); + } + } + + @Override + public void finish() throws Exception { + try (StateSampler.ScopedState finish = + stateSampler.scopedState(finishState)) { + checkStarted(); + fn.finishBundle(); + super.finish(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/PartialGroupByKeyOperation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/PartialGroupByKeyOperation.java new file mode 100644 index 000000000000..a4afa5b2820d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/PartialGroupByKeyOperation.java @@ -0,0 +1,521 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * A partial group-by-key operation. + */ +public class PartialGroupByKeyOperation extends ReceivingOperation { + /** + * Provides client-specific operations for grouping keys. + */ + public static interface GroupingKeyCreator { + public Object createGroupingKey(K key) throws Exception; + } + + /** + * Provides client-specific operations for size estimates. + */ + public static interface SizeEstimator { + public long estimateSize(E element) throws Exception; + } + + /** + * Provides client-specific operations for working with elements + * that are key/value or key/values pairs. + */ + public interface PairInfo { + public Object getKeyFromInputPair(Object pair); + public Object getValueFromInputPair(Object pair); + public Object makeOutputPair(Object key, Object value); + } + + /** + * Provides client-specific operations for combining values. + */ + public interface Combiner { + public VA createAccumulator(K key); + public VA add(K key, VA accumulator, VI value); + public VA merge(K key, Iterable accumulators); + public VO extract(K key, VA accumulator); + } + + /** + * A wrapper around a byte[] that uses structural, value-based + * equality rather than byte[]'s normal object identity. + */ + public static class StructuralByteArray { + byte[] value; + + public StructuralByteArray(byte[] value) { + this.value = value; + } + + public byte[] getValue() { return value; } + + @Override + public boolean equals(Object o) { + if (o instanceof StructuralByteArray) { + StructuralByteArray that = (StructuralByteArray) o; + return Arrays.equals(this.value, that.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "Val" + Arrays.toString(value); + } + } + + // By default, how many bytes we allow the grouping table to consume before + // it has to be flushed. + static final long DEFAULT_MAX_GROUPING_TABLE_BYTES = 100_000_000L; + + // How many bytes a word in the JVM has. + static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord(); + + /** + * The number of bytes of overhead to store an entry in the + * grouping table (a {@code HashMap}), + * ignoring the actual number of bytes in the keys and values: + * + * - an array element (1 word), + * - a HashMap.Entry (4 words), + * - a StructuralByteArray (1 words), + * - a backing array (guessed at 1 word for the length), + * - a KeyAndValues (2 words), + * - an ArrayList (2 words), + * - a backing array (1 word), + * - per-object overhead (JVM-specific, guessed at 2 words * 6 objects). + */ + static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD; + + final GroupingTable groupingTable; + + @SuppressWarnings("unchecked") + public PartialGroupByKeyOperation( + String operationName, + GroupingKeyCreator groupingKeyCreator, + SizeEstimator keySizeEstimator, SizeEstimator valueSizeEstimator, + PairInfo pairInfo, + OutputReceiver[] receivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(operationName, receivers, counterPrefix, addCounterMutator, stateSampler); + groupingTable = new BufferingGroupingTable( + DEFAULT_MAX_GROUPING_TABLE_BYTES, groupingKeyCreator, + pairInfo, keySizeEstimator, valueSizeEstimator); + } + + @SuppressWarnings("unchecked") + public PartialGroupByKeyOperation( + String operationName, + GroupingKeyCreator groupingKeyCreator, + SizeEstimator keySizeEstimator, SizeEstimator valueSizeEstimator, + double sizeEstimatorSampleRate, + PairInfo pairInfo, + OutputReceiver[] receivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, StateSampler stateSampler) { + this(operationName, groupingKeyCreator, + new SamplingSizeEstimator(keySizeEstimator, sizeEstimatorSampleRate, 1.0), + new SamplingSizeEstimator(valueSizeEstimator, sizeEstimatorSampleRate, 1.0), + pairInfo, receivers, counterPrefix, addCounterMutator, stateSampler); + } + + /** Invoked by tests. */ + public PartialGroupByKeyOperation( + GroupingKeyCreator groupingKeyCreator, + SizeEstimator keySizeEstimator, SizeEstimator valueSizeEstimator, + PairInfo pairInfo, + OutputReceiver outputReceiver, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + this("PartialGroupByKeyOperation", groupingKeyCreator, + keySizeEstimator, valueSizeEstimator, pairInfo, + new OutputReceiver[]{ outputReceiver }, + counterPrefix, + addCounterMutator, + stateSampler); + } + + @Override + public void process(Object elem) throws Exception { + try (StateSampler.ScopedState process = + stateSampler.scopedState(processState)) { + if (receivers[0] != null) { + groupingTable.put(elem, receivers[0]); + } + } + } + + @Override + public void finish() throws Exception { + try (StateSampler.ScopedState finish = + stateSampler.scopedState(finishState)) { + checkStarted(); + if (receivers[0] != null) { + groupingTable.flush(receivers[0]); + } + super.finish(); + } + } + + /** + * Sets the maximum amount of memory the grouping table is allowed to + * consume before it has to be flushed. + */ + // @VisibleForTesting + public void setMaxGroupingTableBytes(long maxSize) { + groupingTable.maxSize = maxSize; + } + + /** + * Returns the amount of memory the grouping table currently consumes. + */ + // @VisibleForTesting + public long getGroupingTableBytes() { + return groupingTable.size; + } + + /** + * Returns the number of bytes in a JVM word. In case we failed to + * find the answer, returns 8. + */ + static int getBytesPerJvmWord() { + String wordSizeInBits = System.getProperty("sun.arch.data.model"); + try { + return Integer.parseInt(wordSizeInBits) / 8; + } catch (NumberFormatException e) { + // The JVM word size is unknown. Assume 64-bit. + return 8; + } + } + + private abstract static class GroupingTable { + + // Keep the table relatively full to increase the chance of collisions. + private static final double TARGET_LOAD = 0.9; + + private long maxSize; + private final GroupingKeyCreator groupingKeyCreator; + private final PairInfo pairInfo; + + private long size = 0; + private Map> table; + + public GroupingTable(long maxSize, + GroupingKeyCreator groupingKeyCreator, + PairInfo pairInfo) { + this.maxSize = maxSize; + this.groupingKeyCreator = groupingKeyCreator; + this.pairInfo = pairInfo; + this.table = new HashMap<>(); + } + + interface GroupingTableEntry { + public K getKey(); + public VA getValue(); + public void add(VI value) throws Exception; + public long getSize(); + } + + public abstract GroupingTableEntry createTableEntry(K key) throws Exception; + + /** + * Adds a pair to this table, possibly flushing some entries to output + * if the table is full. + */ + @SuppressWarnings("unchecked") + public void put(Object pair, Receiver receiver) throws Exception { + put((K) pairInfo.getKeyFromInputPair(pair), + (VI) pairInfo.getValueFromInputPair(pair), + receiver); + } + + /** + * Adds the key and value to this table, possibly flushing some entries + * to output if the table is full. + */ + public void put(K key, VI value, Receiver receiver) throws Exception { + Object groupingKey = groupingKeyCreator.createGroupingKey(key); + GroupingTableEntry entry = table.get(groupingKey); + if (entry == null) { + entry = createTableEntry(key); + table.put(groupingKey, entry); + size += PER_KEY_OVERHEAD; + } else { + size -= entry.getSize(); + } + entry.add(value); + size += entry.getSize(); + + if (size >= maxSize) { + long targetSize = (long) (TARGET_LOAD * maxSize); + Iterator> entries = + table.values().iterator(); + while (size >= targetSize) { + if (!entries.hasNext()) { + // Should never happen, but sizes may be estimates... + size = 0; + break; + } + GroupingTableEntry toFlush = entries.next(); + entries.remove(); + size -= toFlush.getSize() + PER_KEY_OVERHEAD; + output(toFlush, receiver); + } + } + } + + /** + * Output the given entry. Does not actually remove it from the table or + * update this table's size. + */ + private void output(GroupingTableEntry entry, Receiver receiver) throws Exception { + receiver.process(pairInfo.makeOutputPair(entry.getKey(), entry.getValue())); + } + + /** + * Flushes all entries in this table to output. + */ + public void flush(Receiver output) throws Exception { + for (GroupingTableEntry entry : table.values()) { + output(entry, output); + } + table.clear(); + size = 0; + } + + } + + /** + * A grouping table that simply buffers all inserted values in a list. + */ + public static class BufferingGroupingTable extends GroupingTable> { + + public final SizeEstimator keySizer; + public final SizeEstimator valueSizer; + + public BufferingGroupingTable(long maxSize, + GroupingKeyCreator groupingKeyCreator, + PairInfo pairInfo, + SizeEstimator keySizer, + SizeEstimator valueSizer) { + super(maxSize, groupingKeyCreator, pairInfo); + this.keySizer = keySizer; + this.valueSizer = valueSizer; + } + + @Override + public GroupingTableEntry> createTableEntry(final K key) throws Exception { + return new GroupingTableEntry>() { + long size = keySizer.estimateSize(key); + final List values = new ArrayList<>(); + public K getKey() { return key; } + public List getValue() { return values; } + public long getSize() { return size; } + public void add(V value) throws Exception { + values.add(value); + size += BYTES_PER_JVM_WORD + valueSizer.estimateSize(value); + } + }; + } + } + + /** + * A grouping table that uses the given combiner to combine values in place. + */ + public static class CombiningGroupingTable extends GroupingTable { + + private final Combiner combiner; + private final SizeEstimator keySizer; + private final SizeEstimator valueSizer; + + public CombiningGroupingTable(long maxSize, + GroupingKeyCreator groupingKeyCreator, + PairInfo pairInfo, + Combiner combineFn, + SizeEstimator keySizer, + SizeEstimator valueSizer) { + super(maxSize, groupingKeyCreator, pairInfo); + this.combiner = combineFn; + this.keySizer = keySizer; + this.valueSizer = valueSizer; + } + + @Override + public GroupingTableEntry createTableEntry(final K key) throws Exception { + return new GroupingTableEntry() { + final long keySize = keySizer.estimateSize(key); + VA accumulator = combiner.createAccumulator(key); + long accumulatorSize = 0; // never used before a value is added... + public K getKey() { return key; } + public VA getValue() { return accumulator; } + public long getSize() { return keySize + accumulatorSize; } + public void add(VI value) throws Exception { + accumulator = combiner.add(key, accumulator, value); + accumulatorSize = valueSizer.estimateSize(accumulator); + } + }; + } + } + + + //////////////////////////////////////////////////////////////////////////// + // Size sampling. + + /** + * Implements size estimation by adaptively delegating to an underlying + * (potentially more expensive) estimator for some elements and returning + * the average value for others. + */ + public static class SamplingSizeEstimator implements SizeEstimator { + + /** + * The degree of confidence required in our expected value predictions + * before we allow under-sampling. + * + *

The value of 3.0 is a confidence interval of about 99.7% for a + * a high-degree-of-freedom t-distribution. + */ + public static final double CONFIDENCE_INTERVAL_SIGMA = 3; + + /** + * The desired size of our confidence interval (relative to the measured + * expected value). + * + *

The value of 0.25 is plus or minus 25%. + */ + public static final double CONFIDENCE_INTERVAL_SIZE = 0.25; + + /** + * Default number of elements that must be measured before elements are skipped. + */ + public static final long DEFAULT_MIN_SAMPLED = 20; + + private final SizeEstimator underlying; + private final double minSampleRate; + private final double maxSampleRate; + private final long minSampled; + private final Random random; + + private long totalElements = 0; + private long sampledElements = 0; + private long sampledSum = 0; + private double sampledSumSquares = 0; + private long estimate; + + private long nextSample = 0; + + public SamplingSizeEstimator( + SizeEstimator underlying, + double minSampleRate, + double maxSampleRate) { + this(underlying, minSampleRate, maxSampleRate, DEFAULT_MIN_SAMPLED, new Random()); + } + + public SamplingSizeEstimator(SizeEstimator underlying, + double minSampleRate, + double maxSampleRate, + long minSampled, + Random random) { + this.underlying = underlying; + this.minSampleRate = minSampleRate; + this.maxSampleRate = maxSampleRate; + this.minSampled = minSampled; + this.random = random; + } + + @Override + public long estimateSize(E element) throws Exception { + if (sampleNow()) { + return recordSample(underlying.estimateSize(element)); + } else { + return estimate; + } + } + + private boolean sampleNow() { + totalElements++; + return --nextSample < 0; + } + + private long recordSample(long value) { + sampledElements += 1; + sampledSum += value; + sampledSumSquares += value * value; + estimate = (long) Math.ceil(sampledSum / sampledElements); + long target = desiredSampleSize(); + if (sampledElements < minSampled || sampledElements < target) { + // Sample immediately. + nextSample = 0; + } else { + double rate = cap( + minSampleRate, + maxSampleRate, + Math.max(1.0 / (totalElements - minSampled + 1), // slowly ramp down + target / (double) totalElements)); // "future" target + // Uses the geometric distribution to return the likely distance between + // successive independent trials of a fixed probability p. This gives the + // same uniform distribution of branching on Math.random() < p, but with + // one random number generation per success rather than one per test, + // which can be a significant savings if p is small. + nextSample = rate == 1.0 + ? 0 + : (long) Math.floor(Math.log(random.nextDouble()) / Math.log(1 - rate)); + } + return value; + } + + private static final double cap(double min, double max, double value) { + return Math.min(max, Math.max(min, value)); + } + + private long desiredSampleSize() { + // We have no a-priori information on the actual distribution of data + // sizes, so compute our desired sample as if it were normal. + // Yes this formula is unstable for small stddev, but we only care about large stddev. + double mean = sampledSum / (double) sampledElements; + double sumSquareDiff = + (sampledSumSquares - (2 * mean * sampledSum) + (sampledElements * mean * mean)); + double stddev = Math.sqrt(sumSquareDiff / (sampledElements - 1)); + double sqrtDesiredSamples = + (CONFIDENCE_INTERVAL_SIGMA * stddev) / (CONFIDENCE_INTERVAL_SIZE * mean); + return (long) Math.ceil(sqrtDesiredSamples * sqrtDesiredSamples); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTracker.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTracker.java new file mode 100644 index 000000000000..fd26caa31e69 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTracker.java @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +/** + * Provides an interface to an object capable of tracking progress through a + * collection of elements to be processed. + * + * @param the type of elements being tracked + */ +public interface ProgressTracker { + /** + * Copies this {@link ProgressTracker}. The copied tracker will maintain its + * own independent notion of the caller's progress through the collection of + * elements being processed. + */ + public ProgressTracker copy(); + + /** + * Reports an element to this {@link ProgressTracker}, as the element is about + * to be processed. + */ + public void saw(T element); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTrackerGroup.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTrackerGroup.java new file mode 100644 index 000000000000..7ed370f16770 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTrackerGroup.java @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +/** + * Implements a group of linked + * {@link ProgressTracker ProgressTrackers} which + * collectively track how far a processing loop has gotten through the elements + * it's processing. Individual {@code ProgressTracker} instances may be copied, + * capturing an independent view of the progress of the system; this turns out + * to be useful for some non-trivial processing loops. The furthest point + * reached by any {@code ProgressTracker} is the one reported. + * + *

This class is abstract. Its single extension point is {@link #report}, + * which should be overriden to provide a function which handles the reporting + * of the supplied element, as appropriate. + * + * @param the type of elements being tracked + */ +public abstract class ProgressTrackerGroup { + // TODO: Instead of an abstract class, strongly consider adding an + // interface like Receiver to the SDK, so that this class can be final and all + // that good stuff. + private long nextIndexToReport = 0; + + public ProgressTrackerGroup() {} + + public final ProgressTracker start() { + return new Tracker(0); + } + + /** Reports the indicated element. */ + protected abstract void report(T element); + + private final class Tracker implements ProgressTracker { + private long nextElementIndex; + + private Tracker(long nextElementIndex) { + this.nextElementIndex = nextElementIndex; + } + + @Override + public ProgressTracker copy() { + return new Tracker(nextElementIndex); + } + + @Override + public void saw(T element) { + long thisElementIndex = nextElementIndex; + nextElementIndex++; + if (thisElementIndex == nextIndexToReport) { + nextIndexToReport = nextElementIndex; + report(element); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTrackingReiterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTrackingReiterator.java new file mode 100644 index 000000000000..8d5d43fa7488 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ProgressTrackingReiterator.java @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.util.common.ForwardingReiterator; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; + +/** + * Implements a {@link Reiterator} which uses a + * {@link ProgressTrackerGroup.Tracker ProgressTracker} to track how far + * it's gotten through some base {@code Reiterator}. + * {@link ProgressTrackingReiterator#copy} copies the {@code ProgressTracker}, + * allowing for an independent progress state. + * + * @param the type of the elements of this iterator + */ +public final class ProgressTrackingReiterator + extends ForwardingReiterator { + private ProgressTracker tracker; + + public ProgressTrackingReiterator(Reiterator base, + ProgressTracker tracker) { + super(base); + this.tracker = checkNotNull(tracker); + } + + @Override + public T next() { + T result = super.next(); + tracker.saw(result); + return result; + } + + @Override + protected ProgressTrackingReiterator clone() { + ProgressTrackingReiterator result = + (ProgressTrackingReiterator) super.clone(); + result.tracker = tracker.copy(); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ReadOperation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ReadOperation.java new file mode 100644 index 000000000000..1930e0e61aaa --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ReadOperation.java @@ -0,0 +1,233 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.common.base.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Observable; +import java.util.Observer; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A read operation. + * + * Its start() method iterates through all elements of the source + * and emits them on its output. + */ +public class ReadOperation extends Operation { + private static final Logger LOG = LoggerFactory.getLogger(ReadOperation.class); + private static final long DEFAULT_PROGRESS_UPDATE_PERIOD_MS = TimeUnit.SECONDS.toMillis(1); + + /** The Source this operation reads from. */ + public final Source source; + + /** The total byte counter for all data read by this operation. */ + final Counter byteCount; + + /** StateSampler state for advancing the SourceIterator. */ + private final int readState; + + /** + * The Source's reader this operation reads from, created by start(). + * Guarded by sourceIteratorLock. + */ + volatile Source.SourceIterator sourceIterator = null; + private final Object sourceIteratorLock = new Object(); + + /** + * A cache of sourceIterator.getProgress() updated inside the read loop at a bounded rate. + *

+ * Necessary so that ReadOperation.getProgress() can return immediately, rather than potentially + * wait for a read to complete (which can take an unbounded time, delay a worker progress update, + * and cause lease expiration and all sorts of trouble). + */ + private AtomicReference progress = new AtomicReference<>(); + + /** + * On every iteration of the read loop, "progress" is fetched from sourceIterator if requested. + */ + private long progressUpdatePeriodMs = DEFAULT_PROGRESS_UPDATE_PERIOD_MS; + + /** + * Signals whether the next iteration of the read loop should update the progress. + * Set to true every progressUpdatePeriodMs. + */ + private AtomicBoolean isProgressUpdateRequested = new AtomicBoolean(true); + + + public ReadOperation(String operationName, Source source, OutputReceiver[] receivers, + String counterPrefix, CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(operationName, receivers, counterPrefix, addCounterMutator, stateSampler); + this.source = source; + this.byteCount = addCounterMutator.addCounter( + Counter.longs(bytesCounterName(counterPrefix, operationName), SUM)); + readState = stateSampler.stateForName(operationName + "-read"); + } + + /** Invoked by tests. */ + ReadOperation(Source source, OutputReceiver outputReceiver, String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, StateSampler stateSampler) { + this("ReadOperation", source, new OutputReceiver[] {outputReceiver}, counterPrefix, + addCounterMutator, stateSampler); + } + + /** + * Invoked by tests. A value of 0 means "update progress on each iteration". + */ + void setProgressUpdatePeriodMs(long millis) { + Preconditions.checkArgument(millis >= 0, "Progress update period must be non-negative"); + progressUpdatePeriodMs = millis; + } + + protected String bytesCounterName(String counterPrefix, String operationName) { + return operationName + "-ByteCount"; + } + + public Source getSource() { + return source; + } + + @Override + public void start() throws Exception { + try (StateSampler.ScopedState start = stateSampler.scopedState(startState)) { + super.start(); + runReadLoop(); + } + } + + protected void runReadLoop() throws Exception { + Receiver receiver = receivers[0]; + if (receiver == null) { + // No consumer of this data; don't do anything. + return; + } + + source.addObserver(new SourceObserver()); + + try (StateSampler.ScopedState process = stateSampler.scopedState(processState)) { + synchronized (sourceIteratorLock) { + sourceIterator = source.iterator(); + } + + // TODO: Consider using the ExecutorService from PipelineOptions instead. + Thread updateRequester = new Thread() { + @Override + public void run() { + while (true) { + isProgressUpdateRequested.set(true); + try { + Thread.sleep(progressUpdatePeriodMs); + } catch (InterruptedException e) { + break; + } + } + } + }; + if (progressUpdatePeriodMs != 0) { + updateRequester.start(); + } + + try { + // Force a progress update at the beginning and at the end. + synchronized (sourceIteratorLock) { + progress.set(sourceIterator.getProgress()); + } + while (true) { + Object value; + // Stop position update request comes concurrently. + // Accesses to iterator need to be synchronized. + try (StateSampler.ScopedState read = stateSampler.scopedState(readState)) { + synchronized (sourceIteratorLock) { + if (!sourceIterator.hasNext()) { + break; + } + value = sourceIterator.next(); + + if (isProgressUpdateRequested.getAndSet(false) || progressUpdatePeriodMs == 0) { + progress.set(sourceIterator.getProgress()); + } + } + } + receiver.process(value); + } + synchronized (sourceIteratorLock) { + progress.set(sourceIterator.getProgress()); + } + } finally { + synchronized (sourceIteratorLock) { + sourceIterator.close(); + } + if (progressUpdatePeriodMs != 0) { + updateRequester.interrupt(); + updateRequester.join(); + } + } + } + } + + /** + * Returns a (possibly slightly stale) value of the progress of the task. + * Guaranteed to not block indefinitely. + * + * @return the task progress, or {@code null} if the source iterator has not + * been initialized + */ + public Source.Progress getProgress() { + return progress.get(); + } + + /** + * Relays the request to update the stop position to {@code SourceIterator}. + * + * @param proposedStopPosition the proposed stop position + * @return the new stop position updated in {@code SourceIterator}, or + * {@code null} if the source iterator has not been initialized + */ + public Source.Position proposeStopPosition(Source.Progress proposedStopPosition) { + synchronized (sourceIteratorLock) { + if (sourceIterator == null) { + LOG.warn("Iterator has not been initialized, returning null stop position."); + return null; + } + return sourceIterator.updateStopPosition(proposedStopPosition); + } + } + + /** + * This is an observer on the instance of the source. Whenever source reads + * an element, update() gets called with the byte size of the element, which + * gets added up into the ReadOperation's byte counter. + */ + private class SourceObserver implements Observer { + @Override + public void update(Observable obs, Object obj) { + Preconditions.checkArgument(obs == source, "unexpected observable" + obs); + Preconditions.checkArgument(obj instanceof Long, "unexpected parameter object: " + obj); + byteCount.addValue((long) obj); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Receiver.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Receiver.java new file mode 100644 index 000000000000..f772ee4c2445 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Receiver.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +/** + * Abstract interface of things that accept inputs one at a time via process(). + */ +public interface Receiver { + /** + * Processes the element. + */ + void process(Object outputElem) throws Exception; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ReceivingOperation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ReceivingOperation.java new file mode 100644 index 000000000000..60deea53fa9c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ReceivingOperation.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * The abstract base class for Operations that have inputs and + * implement process(). + */ +public abstract class ReceivingOperation extends Operation implements Receiver { + + public ReceivingOperation(String operationName, + OutputReceiver[] receivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(operationName, receivers, + counterPrefix, addCounterMutator, stateSampler); + } + + /** + * Adds an input to this Operation, coming from the given + * output of the given source Operation. + */ + public void attachInput(Operation source, int outputNum) { + checkUnstarted(); + OutputReceiver fanOut = source.receivers[outputNum]; + fanOut.addOutput(this); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleBatchReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleBatchReader.java new file mode 100644 index 000000000000..f5102dd14e05 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleBatchReader.java @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import java.io.IOException; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * ShuffleBatchReader provides an interface for reading a batch of + * key/value entries from a shuffle dataset. + */ +public interface ShuffleBatchReader { + /** The result returned by #read. */ + public static class Batch { + public final List entries; + @Nullable public final ShufflePosition nextStartPosition; + + public Batch(List entries, + @Nullable ShufflePosition nextStartPosition) { + this.entries = entries; + this.nextStartPosition = nextStartPosition; + } + } + + /** + * Reads a batch of data from a shuffle dataset. + * + * @param startPosition encodes the initial key from where to read. + * This parameter may be null, indicating that the read should start + * with the first key in the dataset. + * + * @param endPosition encodes the key "just past" the end of the + * range to be read; keys up to endPosition will be returned, but + * keys equal to or greater than endPosition will not. This + * parameter may be null, indicating that the read should end just + * past the last key in the dataset (that is, the last key in the + * dataset will be included in the read, as long as that key is + * greater than or equal to startPosition). + * + * @return the first {@link Batch} of entries + */ + public Batch read(@Nullable ShufflePosition startPosition, + @Nullable ShufflePosition endPosition) + throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntry.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntry.java new file mode 100644 index 000000000000..750c3ac5c71c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntry.java @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import java.util.Arrays; + +/** + * Entry written to/read from a shuffle dataset. + */ +public class ShuffleEntry { + final byte[] position; + final byte[] key; + final byte[] secondaryKey; + final byte[] value; + + public ShuffleEntry(byte[] key, byte[] secondaryKey, byte[] value) { + this.position = null; + this.key = key; + this.secondaryKey = secondaryKey; + this.value = value; + } + + public ShuffleEntry(byte[] position, byte[] key, byte[] secondaryKey, + byte[] value) { + this.position = position; + this.key = key; + this.secondaryKey = secondaryKey; + this.value = value; + } + + public byte[] getPosition() { + return position; + } + + public byte[] getKey() { + return key; + } + + public byte[] getSecondaryKey() { + return secondaryKey; + } + + public byte[] getValue() { + return value; + } + + public int length() { + return (position == null ? 0 : position.length) + + (key == null ? 0 : key.length) + + (secondaryKey == null ? 0 : secondaryKey.length) + + (value == null ? 0 : value.length); + } + + @Override + public String toString() { + return "ShuffleEntry(" + + byteArrayToString(position) + "," + + byteArrayToString(key) + "," + + byteArrayToString(secondaryKey) + "," + + byteArrayToString(value) + ")"; + } + + public static String byteArrayToString(byte[] bytes) { + // TODO: Use a more compact and readable representation, + // particularly for (nearly-)ascii keys and values. + return Arrays.toString(bytes); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof ShuffleEntry) { + ShuffleEntry that = (ShuffleEntry) o; + return (this.position == null ? that.position == null + : Arrays.equals(this.position, that.position)) + && (this.key == null ? that.key == null + : Arrays.equals(this.key, that.key)) + && (this.secondaryKey == null ? that.secondaryKey == null + : Arrays.equals(this.secondaryKey, that.secondaryKey)) + && (this.value == null ? that.value == null + : Arrays.equals(this.value, that.value)); + } + return false; + } + + @Override + public int hashCode() { + return getClass().hashCode() + + (position == null ? 0 : Arrays.hashCode(position)) + + (key == null ? 0 : Arrays.hashCode(key)) + + (secondaryKey == null ? 0 : Arrays.hashCode(secondaryKey)) + + (value == null ? 0 : Arrays.hashCode(value)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntryReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntryReader.java new file mode 100644 index 000000000000..bbc5f47a4b8c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntryReader.java @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.Reiterator; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * ShuffleEntryReader provides an interface for reading key/value + * entries from a shuffle dataset. + */ +@NotThreadSafe +public interface ShuffleEntryReader { + /** + * Returns an iterator which reads a range of entries from a shuffle dataset. + * + * @param startPosition encodes the initial key from where to read. + * This parameter may be null, indicating that the read should start + * with the first key in the dataset. + * + * @param endPosition encodes the key "just past" the end of the + * range to be read; keys up to endPosition will be returned, but + * keys equal to or greater than endPosition will not. This + * parameter may be null, indicating that the read should end just + * past the last key in the dataset (that is, the last key in the + * dataset will be included in the read, as long as that key is + * greater than or equal to startPosition). + * + * @return a {@link Reiterator} over the requested range of entries. + */ + public Reiterator read( + @Nullable ShufflePosition startPosition, + @Nullable ShufflePosition endPosition); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShufflePosition.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShufflePosition.java new file mode 100644 index 000000000000..c512269a4950 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/ShufflePosition.java @@ -0,0 +1,23 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +/** + * Represents a position in a stream of ShuffleEntries. + */ +public interface ShufflePosition { +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Sink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Sink.java new file mode 100644 index 000000000000..829fd1a39153 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Sink.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import java.io.IOException; + +/** + * Abstract base class for Sinks. + * + *

A Sink is written to by getting a SinkWriter and adding values to + * it. + * + * @param the type of the elements written to the sink + */ +public abstract class Sink { + /** + * Returns a Writer that allows writing to this Sink. + */ + public abstract SinkWriter writer() throws IOException; + + /** + * Writes to a Sink. + */ + public interface SinkWriter extends AutoCloseable { + /** + * Adds a value to the sink. Returns the size in bytes of the data written. + */ + public long add(E value) throws IOException; + + @Override + public void close() throws IOException; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Source.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Source.java new file mode 100644 index 000000000000..d50b93dc5419 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/Source.java @@ -0,0 +1,157 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.Observable; + +/** + * Abstract base class for Sources. + * + *

A Source is read from by getting an Iterator-like value and + * iterating through it. + * + * @param the type of the elements read from the source + */ +public abstract class Source extends Observable { + /** + * Returns a SourceIterator that allows reading from this source. + */ + public abstract SourceIterator iterator() throws IOException; + + /** + * A stateful iterator over the data in a Source. + */ + public interface SourceIterator extends AutoCloseable { + /** + * Returns whether the source has any more elements. Some sources, + * such as GroupingShuffleSource, invalidate the return value of + * the previous next() call during the call to hasNext(). + */ + public boolean hasNext() throws IOException; + + /** + * Returns the next element. + * + * @throws NoSuchElementException if there are no more elements + */ + public T next() throws IOException; + + /** + * Copies the current SourceIterator. + * + * @throws UnsupportedOperationException if the particular implementation + * does not support copy + * @throws IOException if copying the iterator involves IO that fails + */ + public SourceIterator copy() throws IOException; + + @Override + public void close() throws IOException; + + /** + * Returns a representation of how far this iterator is through the source. + * + *

This method is not required to be thread-safe, and it will not be + * called concurrently to any other methods. + * + * @return the progress, or {@code null} if no progress measure + * can be provided + */ + public Progress getProgress(); + + /** + * Attempts to update the stop position of the task with the proposed stop + * position and returns the actual new stop position. + * + *

If the source finds the proposed one is not a convenient position to + * stop, it can pick a different stop position. The {@code SourceIterator} + * should start returning {@code false} from {@code hasNext()} once it has + * passed its stop position. Subsequent stop position updates must be in + * non-increasing order within a task. + * + *

This method is not required to be thread-safe, and it will not be + * called concurrently to any other methods. + * + * @param proposedStopPosition a proposed position to stop + * iterating through the source + * @return the new stop position, or {@code null} on failure if the + * implementation does not support position updates. + */ + public Position updateStopPosition(Progress proposedStopPosition); + } + + /** An abstract base class for SourceIterator implementations. */ + public abstract static class AbstractSourceIterator + implements SourceIterator { + @Override + public SourceIterator copy() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void close() throws IOException { + // By default, nothing is needed for close. + } + + @Override + public Progress getProgress() { + return null; + } + + @Override + public Position updateStopPosition(Progress proposedStopPosition) { + return null; + } + } + + /** + * A representation of how far a {@code SourceIterator} is through a + * {@code Source}. + * + *

The common worker framework does not interpret instances of + * this interface. But a tool-specific framework can make assumptions + * about the implementation, and so the concrete Source subclasses used + * by a tool-specific framework should match. + */ + public interface Progress { + } + + /** + * A representation of a position in an iteration through a + * {@code Source}. + * + *

See the comment on {@link Progress} for how instances of this + * interface are used by the rest of the framework. + */ + public interface Position { + } + + /** + * Utility method to notify observers about a new element, which has + * been read by this Source, and its size in bytes. Normally, there + * is only one observer, which is a ReadOperation that encapsules + * this Source. Derived classes must call this method whenever they + * read additional data, even if that element may never be returned + * from the corresponding source iterator. + */ + protected void notifyElementRead(long byteSize) { + setChanged(); + notifyObservers(byteSize); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSampler.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSampler.java new file mode 100644 index 000000000000..91d90e9d2a05 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSampler.java @@ -0,0 +1,279 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import java.util.Timer; +import java.util.TimerTask; + +/** + * A StateSampler object may be used to obtain an approximate + * breakdown of the time spent by an execution context in various + * states, as a fraction of the total time. The sampling is taken at + * regular intervals, with adjustment for scheduling delay. + * + *

Thread-safe. + */ +public class StateSampler extends TimerTask implements AutoCloseable { + private final String prefix; + private CounterSet.AddCounterMutator counterSetMutator; + // Sampling period of internal Timer (thread). + public final long samplingPeriodMs; + public static final int DO_NOT_SAMPLE = -1; + public static final long DEFAULT_SAMPLING_PERIOD_MS = 200; + // Array of counters indexed by their state. + private ArrayList> countersByState = new ArrayList<>(); + // Map of state name to state. + private HashMap statesByName = new HashMap<>(); + // The current state. + private int currentState; + // The timestamp corresponding to the last state change or the last + // time the current state was sampled (and recorded). + private long stateTimestamp = 0; + + // When sampling this state, a stack trace is also logged. + private int stateToSampleThreadStacks = DO_NOT_SAMPLE; + // The thread that performed the last state transition. + private Thread sampledThread = null; + // The frequency with which the stack traces are logged, with respect + // to the sampling period. + private static final int SAMPLE_THREAD_STACK_FREQ = 10; + private int sampleThreadStackFreq = 0; + + // Using a fixed number of timers for all StateSampler objects. + private static final int NUM_TIMER_THREADS = 16; + // The timers is used for periodically sampling the states. + private static Timer[] timers = new Timer[NUM_TIMER_THREADS]; + static { + for (int i = 0; i < timers.length; ++i) { + timers[i] = new Timer("StateSampler_" + i, true /* is daemon */); + } + } + + /** + * Constructs a new {@link StateSampler} that can be used to obtain + * an approximate breakdown of the time spent by an execution + * context in various states, as a fraction of the total time. + * + * @param prefix the prefix of the counter names for the states + * @param counterSetMutator the {@link CounterSet.AddCounterMutator} + * used to create a counter for each distinct state + * @param samplingPeriodMs the sampling period in milliseconds + */ + public StateSampler(String prefix, + CounterSet.AddCounterMutator counterSetMutator, + long samplingPeriodMs) { + this.prefix = prefix; + this.counterSetMutator = counterSetMutator; + this.samplingPeriodMs = samplingPeriodMs; + currentState = DO_NOT_SAMPLE; + Random rand = new Random(); + int initialDelay = rand.nextInt((int) samplingPeriodMs); + timers[rand.nextInt(NUM_TIMER_THREADS)].scheduleAtFixedRate( + this, initialDelay, samplingPeriodMs); + stateTimestamp = System.currentTimeMillis(); + } + + /** + * Constructs a new {@link StateSampler} that can be used to obtain + * an approximate breakdown of the time spent by an execution + * context in various states, as a fraction of the total time. + * + * @param prefix the prefix of the counter names for the states + * @param counterSetMutator the {@link CounterSet.AddCounterMutator} + * used to create a counter for each distinct state + */ + public StateSampler(String prefix, + CounterSet.AddCounterMutator counterSetMutator) { + this(prefix, counterSetMutator, DEFAULT_SAMPLING_PERIOD_MS); + } + + private void printStackTrace(Thread thread) { + System.out.println("Sampled stack trace:"); + StackTraceElement[] stack = thread.getStackTrace(); + for (StackTraceElement elem : stack) { + System.out.println("\t" + elem.toString()); + } + } + + /** + * Selects a state for which the thread stacks will also be logged + * during the sampling. Useful for debugging. + * + * @param state name of the selected state + */ + public synchronized void setStateToSampleThreadStacks(int state) { + stateToSampleThreadStacks = state; + } + + @Override + public synchronized void run() { + long now = System.currentTimeMillis(); + if (currentState != DO_NOT_SAMPLE) { + countersByState.get(currentState).addValue(now - stateTimestamp); + if (sampledThread != null + && currentState == stateToSampleThreadStacks + && ++sampleThreadStackFreq >= SAMPLE_THREAD_STACK_FREQ) { + printStackTrace(sampledThread); + sampleThreadStackFreq = 0; + } + } + stateTimestamp = now; + } + + @Override + public void close() { + this.cancel(); // cancel the TimerTask + } + + /** + * Returns the state associated with a name; creating a new state if + * necessary. Using states instead of state names during state + * transitions is done for efficiency. + * + * @name the name for the state + * @return the state associated with the state name + */ + public int stateForName(String name) { + if (name.isEmpty()) { + return DO_NOT_SAMPLE; + } + + String counterName = prefix + name + "-msecs"; + synchronized (this) { + Integer state = statesByName.get(counterName); + if (state == null) { + Counter counter = counterSetMutator.addCounter( + Counter.longs(counterName, Counter.AggregationKind.SUM)); + state = countersByState.size(); + statesByName.put(name, state); + countersByState.add(counter); + } + return state; + } + } + + /** + * Sets the current thread state. + * + * @param state the new state to transition to + * @return the previous state + */ + public synchronized int setState(int state) { + // TODO: investigate whether this can be made cheaper, (e.g., + // using atomic operations). + int previousState = currentState; + currentState = state; + if (stateToSampleThreadStacks != DO_NOT_SAMPLE) { + sampledThread = Thread.currentThread(); + } + return previousState; + } + + /** + * Sets the current thread state. + * + * @param name the name of the new state to transition to + * @return the previous state + */ + public synchronized int setState(String name) { + return setState(stateForName(name)); + } + + /** + * Returns a tuple consisting of the current state and duration. + * + * @return a {@link Map.Entry} entry with current state and duration + */ + public synchronized Map.Entry getCurrentStateAndDuration() { + if (currentState == DO_NOT_SAMPLE) { + return new SimpleEntry<>("", 0L); + } + + Counter counter = countersByState.get(currentState); + return new SimpleEntry<>(counter.getName(), + counter.getAggregate(false) + + System.currentTimeMillis() - stateTimestamp); + } + + /** + * Get the duration for a given state. + * + * @param state the state whose duration is returned + * @return the duration of a given state + */ + public synchronized long getStateDuration(int state) { + Counter counter = countersByState.get(state); + return counter.getAggregate(false) + + (state == currentState + ? System.currentTimeMillis() - stateTimestamp : 0); + } + + /** + * Returns an AutoCloseable {@link ScopedState} that will perform a + * state transition to the given state, and will automatically reset + * the state to the prior state upon closing. + * + * @param state the new state to transition to + * @return a {@link ScopedState} that automatically resets the state + * to the prior state + */ + public synchronized ScopedState scopedState(int state) { + return new ScopedState(this, setState(state)); + } + + /** + * Returns an AutoCloseable {@link ScopedState} that will perform a + * state transition to the given state, and will automatically reset + * the state to the prior state upon closing. + * + * @param stateName the name of the new state + * @return a {@link ScopedState} that automatically resets the state + * to the prior state + */ + public synchronized ScopedState scopedState(String stateName) { + return new ScopedState(this, setState(stateName)); + } + + /** + * A nested class that is used to account for states and state + * transitions based on lexical scopes. + * + *

Thread-safe. + */ + public class ScopedState implements AutoCloseable { + private StateSampler sampler; + private int previousState; + + private ScopedState(StateSampler sampler, int previousState) { + this.sampler = sampler; + this.previousState = previousState; + } + + @Override + public void close() { + sampler.setState(previousState); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkExecutor.java new file mode 100644 index 000000000000..63270b682ebc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkExecutor.java @@ -0,0 +1,99 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.Metric; +import com.google.cloud.dataflow.sdk.util.common.Metric.DoubleMetric; + +import com.sun.management.OperatingSystemMXBean; + +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Abstract executor for WorkItem tasks. + */ +public abstract class WorkExecutor implements AutoCloseable { + /** The output counters for this task. */ + private final CounterSet outputCounters; + + /** + * OperatingSystemMXBean for reporting CPU usage. + * + * Uses com.sun.management.OperatingSystemMXBean instead of + * java.lang.management.OperatingSystemMXBean because the former supports + * getProcessCpuLoad(). + */ + private final OperatingSystemMXBean os; + + /** + * Constructs a new WorkExecutor task. + */ + public WorkExecutor(CounterSet outputCounters) { + this.outputCounters = outputCounters; + this.os = + (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean(); + } + + /** + * Returns the set of output counters for this task. + */ + public CounterSet getOutputCounters() { + return outputCounters; + } + + /** + * Returns a collection of output metrics for this task. + */ + public Collection> getOutputMetrics() { + List> outputMetrics = new ArrayList<>(); + outputMetrics.add(new DoubleMetric("CPU", os.getProcessCpuLoad())); + // More metrics as needed. + return outputMetrics; + } + + /** + * Executes the task. + */ + public abstract void execute() throws Exception; + + /** + * Returns the worker's current progress. + */ + public Source.Progress getWorkerProgress() throws Exception { + // By default, return null indicating worker progress not available. + return null; + } + + /** + * Proposes that the worker changes the stop position for the current work. + * Returns the new position if accepted, otherwise {@code null}. + */ + public Source.Position proposeStopPosition( + Source.Progress proposedStopPosition) throws Exception { + // By default, returns null indicating that no task splitting happens. + return null; + } + + @Override + public void close() throws Exception { + // By default, nothing to close or shut down. + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkProgressUpdater.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkProgressUpdater.java new file mode 100644 index 000000000000..c5222eb04a2f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkProgressUpdater.java @@ -0,0 +1,239 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * WorkProgressUpdater allows a work executor to send work progress + * updates to the worker service. The life-cycle of the + * WorkProgressUpdater is controlled externally through its + * {@link #startReportingProgress()} and + * {@link #stopReportingProgress()} methods. The updater queries the + * worker for progress updates and sends the updates to the worker + * service. The interval between two consecutive updates is + * controlled by the worker service through reporting interval hints + * sent back in the update response messages. To avoid update storms + * and monitoring staleness, the interval between two consecutive + * updates is also bound by {@link #MIN_REPORTING_INTERVAL_MILLIS} and + * {@link #MAX_REPORTING_INTERVAL_MILLIS}. + */ +@NotThreadSafe +public abstract class WorkProgressUpdater { + private static final Logger LOG = LoggerFactory.getLogger(WorkProgressUpdater.class); + + /** The default lease duration to request from the external worker service. */ + private static final long DEFAULT_LEASE_DURATION_MILLIS = 3 * 60 * 1000; + + /** The lease renewal RPC latency margin. */ + private static final long LEASE_RENEWAL_LATENCY_MARGIN = Long.valueOf( + System.getProperty("worker_lease_renewal_latency_margin", "5000")); + + /** + * The minimum period between two consecutive progress updates. Ensures the + * {@link WorkProgressUpdater} does not generate update storms. + */ + private static final long MIN_REPORTING_INTERVAL_MILLIS = Long.valueOf( + System.getProperty("minimum_worker_update_interval_millis", "5000")); + + /** + * The maximum period between two consecutive progress updates. Ensures the + * {@link WorkProgressUpdater} does not cause monitoring staleness. + */ + private static final long MAX_REPORTING_INTERVAL_MILLIS = 10 * 60 * 1000; + + /** Worker providing the work progress updates. */ + protected final WorkExecutor worker; + + /** Executor used to schedule work progress updates. */ + private final ScheduledExecutorService executor; + + /** The lease duration to request from the external worker service. */ + protected long requestedLeaseDurationMs; + + /** The time period until the next work progress update. */ + protected long progressReportIntervalMs; + + /** + * The stop position to report to the service in the next progress update, + * or {@code null} if there is nothing to report. + * In cases that there is no split request from service, or worker failed to + * split in response to the last received split request, the task stop + * position implicitly stays the same as it was before that last request + * (as a result of a prior split request), and on the next reportProgress + * we'll send the {@code null} as a stop position update, which is a no-op + * for the service. + */ + protected Source.Position stopPositionToService; + + public WorkProgressUpdater(WorkExecutor worker) { + this.worker = worker; + this.executor = Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("WorkProgressUpdater-%d") + .build()); + } + + /** + * Starts sending work progress updates to the worker service. + */ + public void startReportingProgress() { + // Send the initial work progress report half-way through the lease + // expiration. Subsequent intervals adapt to hints from the service. + long leaseRemainingTime = + leaseRemainingTime(getWorkUnitLeaseExpirationTimestamp()); + progressReportIntervalMs = nextProgressReportInterval( + leaseRemainingTime / 2, leaseRemainingTime); + requestedLeaseDurationMs = DEFAULT_LEASE_DURATION_MILLIS; + + LOG.info("Started reporting progress for work item: {}", workString()); + scheduleNextUpdate(); + } + + /** + * Stops sending work progress updates to the worker service. + * It may throw an exception if the final progress report fails to be sent for some reason. + */ + public void stopReportingProgress() throws Exception { + // TODO: Redesign to get rid of the executor and use a dedicated + // thread with a sleeper. Also unify with success/failure reporting. + + // Wait until there are no more progress updates in progress, then + // shut down. + synchronized (executor) { + executor.shutdownNow(); + } + + // We send a final progress report in case there was an unreported stop position update. + if (stopPositionToService != null) { + LOG.info("Sending final progress update with unreported stop position."); + reportProgressHelper(); // This call can fail with an exception + } + + LOG.info("Stopped reporting progress for work item: {}", workString()); + } + + /** + * Computes the time before sending the next work progress update making sure + * that it falls between the [{@link #MIN_REPORTING_INTERVAL_MILLIS}, + * {@link #MAX_REPORTING_INTERVAL_MILLIS}) interval. Makes an attempt to bound + * the result by the remaining lease time, with an RPC latency margin of + * {@link #LEASE_RENEWAL_LATENCY_MARGIN}. + * + * @param suggestedInterval the suggested progress report interval + * @param leaseRemainingTime milliseconds left before the work lease expires + * @return the time in milliseconds before sending the next progress update + */ + protected static long nextProgressReportInterval(long suggestedInterval, + long leaseRemainingTime) { + // Sanitize input in case we get a negative suggested time interval. + suggestedInterval = Math.max(0, suggestedInterval); + + // Try to send the next progress update before the next lease expiration + // allowing some RPC latency margin. + suggestedInterval = Math.min(suggestedInterval, + leaseRemainingTime - LEASE_RENEWAL_LATENCY_MARGIN); + + // Bound reporting interval to avoid staleness and progress update storms. + return Math.min(Math.max(MIN_REPORTING_INTERVAL_MILLIS, suggestedInterval), + MAX_REPORTING_INTERVAL_MILLIS); + } + + /** + * Schedules the next work progress update. + */ + private void scheduleNextUpdate() { + if (executor.isShutdown()) { + return; + } + executor.schedule(new Runnable() { + @Override + public void run() { + // Don't shut down while reporting progress. + synchronized (executor) { + if (executor.isShutdown()) { + return; + } + reportProgress(); + } + } + }, progressReportIntervalMs, TimeUnit.MILLISECONDS); + LOG.debug("Next work progress update for work item {} scheduled to occur in {} ms.", + workString(), progressReportIntervalMs); + } + + /** + * Reports the current work progress to the worker service. + */ + private void reportProgress() { + LOG.info("Updating progress on work item {}", workString()); + try { + reportProgressHelper(); + } catch (Throwable e) { + LOG.warn("Error reporting work progress update: ", e); + } finally { + scheduleNextUpdate(); + } + } + + /** + * Computes the amount of time left, in milliseconds, before a lease + * with the specified expiration timestamp expires. Returns zero if + * the lease has already expired. + */ + protected long leaseRemainingTime(long leaseExpirationTimestamp) { + long now = System.currentTimeMillis(); + if (leaseExpirationTimestamp < now) { + LOG.debug("Lease remaining time for {} is 0 ms.", workString()); + return 0; + } + LOG.debug("Lease remaining time for {} is {} ms.", + workString(), leaseExpirationTimestamp - now); + return leaseExpirationTimestamp - now; + } + + // Visible for testing. + public Source.Position getStopPosition() { + return stopPositionToService; + } + + /** + * Reports the current work progress to the worker service. + */ + protected abstract void reportProgressHelper() throws Exception; + + /** + * Returns the current work item's lease expiration timestamp. + */ + protected abstract long getWorkUnitLeaseExpirationTimestamp(); + + /** + * Returns a string representation of the work item whose progress + * is being updated, for use in logging messages. + */ + protected abstract String workString(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WriteOperation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WriteOperation.java new file mode 100644 index 000000000000..6f8b2e586548 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/WriteOperation.java @@ -0,0 +1,105 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * A write operation. + */ +public class WriteOperation extends ReceivingOperation { + /** + * The Sink this operation writes to. + */ + public final Sink sink; + + /** + * The total byte counter for all data written by this operation. + */ + final Counter byteCount; + + /** + * The Sink's writer this operation writes to, created by start(). + */ + Sink.SinkWriter writer; + + public WriteOperation(String operationName, + Sink sink, + OutputReceiver[] receivers, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(operationName, receivers, + counterPrefix, addCounterMutator, stateSampler); + this.sink = sink; + this.byteCount = addCounterMutator.addCounter( + Counter.longs(bytesCounterName(counterPrefix, operationName), SUM)); + } + + /** Invoked by tests. */ + public WriteOperation(Sink sink, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + this("WriteOperation", sink, new OutputReceiver[]{ }, + counterPrefix, addCounterMutator, stateSampler); + } + + protected String bytesCounterName(String counterPrefix, + String operationName) { + return operationName + "-ByteCount"; + } + + public Sink getSink() { + return sink; + } + + @Override + public void start() throws Exception { + try (StateSampler.ScopedState start = + stateSampler.scopedState(startState)) { + super.start(); + writer = sink.writer(); + } + } + + @Override + public void process(Object outputElem) throws Exception { + try (StateSampler.ScopedState process = + stateSampler.scopedState(processState)) { + checkStarted(); + byteCount.addValue(writer.add(outputElem)); + } + } + + @Override + public void finish() throws Exception { + try (StateSampler.ScopedState finish = + stateSampler.scopedState(finishState)) { + checkStarted(); + writer.close(); + super.finish(); + } + } + + public Counter getByteCount() { + return byteCount; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/package-info.java new file mode 100644 index 000000000000..1bef723c9ac7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities used to implement the harness that runs user code. **/ +package com.google.cloud.dataflow.sdk.util.common.worker; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPath.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPath.java new file mode 100644 index 000000000000..f1da8b767ef2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPath.java @@ -0,0 +1,617 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsfs; + +import com.google.api.client.util.Preconditions; +import com.google.api.client.util.Strings; +import com.google.api.services.storage.model.StorageObject; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.FileSystem; +import java.nio.file.LinkOption; +import java.nio.file.Path; +import java.nio.file.WatchEvent; +import java.nio.file.WatchKey; +import java.nio.file.WatchService; +import java.util.Iterator; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Implements the Java NIO {@link Path} API for Google Cloud Storage paths. + * + *

GcsPath uses a slash ('/') as a directory separator. Below is + * a summary of how slashes are treated: + *

    + *
  • A GCS bucket may not contain a slash. An object may contain zero or + * more slashes. + *
  • A trailing slash always indicates a directory, which is compliant + * with POSIX.1-2008. + *
  • Slashes separate components of a path. Empty components are allowed, + * which is represented as repeated slashes. An empty component always + * refers to a directory, and always ends in a slash. + *
  • {@link #getParent()}} always returns a path ending in a slash, as the + * parent of a GcsPath is always a directory. + *
  • Use {@link #resolve(String)} to append elements to a GcsPath -- this + * applies the rules consistently and is highly recommended over any + * custom string concatenation. + *
+ * + *

GcsPath treats all GCS objects and buckets as belonging to the same + * filesystem, so the root of a GcsPath is the GcsPath bucket="", object="". + * + *

Relative paths are not associated with any bucket. This matches common + * treatment of Path in which relative paths can be constructed from one + * filesystem and appended to another filesystem. + * + * @see Java Tutorials: Path Operations + */ +public class GcsPath implements Path { + + public static final String SCHEME = "gs"; + + /** + * Creates a GcsPath from a URI. + * + *

The URI must be in the form {@code gs://[bucket]/[path]}, and may not + * contain a port, user info, a query, or a fragment. + */ + public static GcsPath fromUri(URI uri) { + Preconditions.checkArgument(uri.getScheme().equalsIgnoreCase(SCHEME), + "URI: %s is not a GCS URI", uri); + Preconditions.checkArgument(uri.getPort() == -1, + "GCS URI may not specify port: %s (%i)", uri, uri.getPort()); + Preconditions.checkArgument( + Strings.isNullOrEmpty(uri.getUserInfo()), + "GCS URI may not specify userInfo: %s (%s)", uri, uri.getUserInfo()); + Preconditions.checkArgument( + Strings.isNullOrEmpty(uri.getQuery()), + "GCS URI may not specify query: %s (%s)", uri, uri.getQuery()); + Preconditions.checkArgument( + Strings.isNullOrEmpty(uri.getFragment()), + "GCS URI may not specify fragment: %s (%s)", uri, uri.getFragment()); + + return fromUri(uri.toString()); + } + + /** + * Pattern which is used to parse a GCS URL. + * + *

This is used to separate the components. Verification is handled + * separately. + */ + private static final Pattern GCS_URI = + Pattern.compile("(?[^:]+)://(?[^/]+)(/(?.*))?"); + + /** + * Creates a GcsPath from a URI in string form. + * + *

This does not use URI parsing, which means it may accept patterns that + * the URI parser would not accept. + */ + public static GcsPath fromUri(String uri) { + Matcher m = GCS_URI.matcher(uri); + Preconditions.checkArgument(m.matches(), "Invalid GCS URI: %s", uri); + + Preconditions.checkArgument(m.group("SCHEME").equalsIgnoreCase(SCHEME), + "URI: %s is not a GCS URI", uri); + return new GcsPath(null, m.group("BUCKET"), m.group("OBJECT")); + } + + /** + * Pattern which is used to parse a GCS resource name. + */ + private static final Pattern GCS_RESOURCE_NAME = + Pattern.compile("storage.googleapis.com/(?[^/]+)(/(?.*))?"); + + /** + * Creates a GcsPath from a OnePlatform resource name in string form. + */ + public static GcsPath fromResourceName(String name) { + Matcher m = GCS_RESOURCE_NAME.matcher(name); + Preconditions.checkArgument(m.matches(), "Invalid GCS resource name: %s", name); + + return new GcsPath(null, m.group("BUCKET"), m.group("OBJECT")); + } + + /** + * Creates a GcsPath from a {@linkplain StorageObject}. + */ + public static GcsPath fromObject(StorageObject object) { + return new GcsPath(null, object.getBucket(), object.getName()); + } + + /** + * Creates a GcsPath from bucket and object components. + * + *

A GcsPath without a bucket name is treated as a relative path, which + * is a path component with no linkage to the root element. This is similar + * to a Unix path which does not begin with the root marker (a slash). + * GCS has different naming constraints and APIs for working with buckets and + * objects, so these two concepts are kept separate to avoid accidental + * attempts to treat objects as buckets, or vice versa, as much as possible. + * + *

A GcsPath without an object name is a bucket reference. + * A bucket is always a directory, which could be used to lookup or add + * files to a bucket, but could not be opened as a file. + * + *

A GcsPath containing neither bucket or object names is treated as + * the root of the GCS filesystem. A listing on the root element would return + * the buckets available to the user. + * + *

If {@code null} is passed as either parameter, it is converted to an + * empty string internally for consistency. There is no distinction between + * an empty string and a {@code null}, as neither are allowed by GCS. + * + * @param bucket a GCS bucket name, or none ({@code null} or an empty string) + * if the object is not associated with a bucket + * (e.g. relative paths or the root node). + * @param object a GCS object path, or none ({@code null} or an empty string) + * for no object. + */ + public static GcsPath fromComponents(@Nullable String bucket, + @Nullable String object) { + return new GcsPath(null, bucket, object); + } + + @Nullable + private FileSystem fs; + @Nonnull + private final String bucket; + @Nonnull + private final String object; + + /** + * Constructs a GcsPath. + * + * @param fs the associated FileSystem, if any + * @param bucket the associated bucket, or none ({@code null} or an empty + * string) for a relative path component + * @param object the object, which is a fully-qualified object name if bucket + * was also provided, or none ({@code null} or an empty string) + * for no object + * @throws java.lang.IllegalArgumentException if the bucket of object names + * are invalid. + */ + public GcsPath(@Nullable FileSystem fs, + @Nullable String bucket, + @Nullable String object) { + if (bucket == null) { + bucket = ""; + } + Preconditions.checkArgument(!bucket.contains("/"), + "GCS bucket may not contain a slash"); + Preconditions + .checkArgument(bucket.isEmpty() + || bucket.matches("[a-z0-9][-_a-z0-9.]+[a-z0-9]"), + "GCS bucket names must contain only lowercase letters, numbers, " + + "dashes (-), underscores (_), and dots (.). Bucket names " + + "must start and end with a number or letter. " + + "See https://developers.google.com/storage/docs/bucketnaming " + + "for more details. Bucket name: " + bucket); + + if (object == null) { + object = ""; + } + Preconditions.checkArgument( + object.indexOf('\n') < 0 && object.indexOf('\r') < 0, + "GCS object names must not contain Carriage Return or " + + "Line Feed characters."); + + this.fs = fs; + this.bucket = bucket; + this.object = object; + } + + /** + * Returns the bucket name associated with this GCS path, or an empty string + * if this is a relative path component. + */ + public String getBucket() { + return bucket; + } + + /** + * Returns the object name associated with this GCS path, or an empty string + * if no object is specified. + */ + public String getObject() { + return object; + } + + public void setFileSystem(FileSystem fs) { + this.fs = fs; + } + + @Override + public FileSystem getFileSystem() { + return fs; + } + + // Absolute paths are those which have a bucket and the root path. + @Override + public boolean isAbsolute() { + return !bucket.isEmpty() || object.isEmpty(); + } + + @Override + public GcsPath getRoot() { + return new GcsPath(fs, "", ""); + } + + @Override + public GcsPath getFileName() { + throw new UnsupportedOperationException(); + } + + /** + * Returns the parent path, or {@code null} if this path does not + * have a parent. + * + *

Returns a path which ends in '/', as the parent path always refers to + * a directory. + */ + @Override + public GcsPath getParent() { + if (bucket.isEmpty() && object.isEmpty()) { + // The root path has no parent, by definition. + return null; + } + + if (object.isEmpty()) { + // A GCS bucket. All buckets come from a common root. + return getRoot(); + } + + // Skip last character, in case it is a trailing slash. + int i = object.lastIndexOf('/', object.length() - 2); + if (i <= 0) { + if (bucket.isEmpty()) { + // Relative paths are not attached to the root node. + return null; + } + return new GcsPath(fs, bucket, ""); + } + + // Retain trailing slash. + return new GcsPath(fs, bucket, object.substring(0, i + 1)); + } + + @Override + public int getNameCount() { + int count = bucket.isEmpty() ? 0 : 1; + if (object.isEmpty()) { + return count; + } + + // Add another for each separator found. + int index = -1; + while ((index = object.indexOf('/', index + 1)) != -1) { + count++; + } + + return object.endsWith("/") ? count : count + 1; + } + + @Override + public GcsPath getName(int count) { + Preconditions.checkArgument(count >= 0); + + Iterator iterator = iterator(); + for (int i = 0; i < count; ++i) { + Preconditions.checkArgument(iterator.hasNext()); + iterator.next(); + } + + Preconditions.checkArgument(iterator.hasNext()); + return (GcsPath) iterator.next(); + } + + @Override + public GcsPath subpath(int beginIndex, int endIndex) { + Preconditions.checkArgument(beginIndex >= 0); + Preconditions.checkArgument(endIndex > beginIndex); + + Iterator iterator = iterator(); + for (int i = 0; i < beginIndex; ++i) { + Preconditions.checkArgument(iterator.hasNext()); + iterator.next(); + } + + GcsPath path = null; + while (beginIndex < endIndex) { + Preconditions.checkArgument(iterator.hasNext()); + if (path == null) { + path = (GcsPath) iterator.next(); + } else { + path = path.resolve(iterator.next()); + } + ++beginIndex; + } + + return path; + } + + @Override + public boolean startsWith(Path other) { + if (other instanceof GcsPath) { + GcsPath gcsPath = (GcsPath) other; + return startsWith(gcsPath.bucketAndObject()); + } else { + return startsWith(other.toString()); + } + } + + @Override + public boolean startsWith(String prefix) { + return bucketAndObject().startsWith(prefix); + } + + @Override + public boolean endsWith(Path other) { + if (other instanceof GcsPath) { + GcsPath gcsPath = (GcsPath) other; + return endsWith(gcsPath.bucketAndObject()); + } else { + return endsWith(other.toString()); + } + } + + @Override + public boolean endsWith(String suffix) { + return bucketAndObject().endsWith(suffix); + } + + // TODO: support "." and ".." path components? + @Override + public GcsPath normalize() { return this; } + + @Override + public GcsPath resolve(Path other) { + if (other instanceof GcsPath) { + GcsPath path = (GcsPath) other; + if (path.isAbsolute()) { + return path; + } else { + return resolve(path.getObject()); + } + } else { + return resolve(other.toString()); + } + } + + @Override + public GcsPath resolve(String other) { + if (bucket.isEmpty() && object.isEmpty()) { + // Resolve on a root path is equivalent to looking up a bucket and object. + other = SCHEME + "://" + other; + } + + if (other.startsWith(SCHEME + "://")) { + GcsPath path = GcsPath.fromUri(other); + path.setFileSystem(getFileSystem()); + return path; + } + + if (other.isEmpty()) { + // An empty component MUST refer to a directory. + other = "/"; + } + + if (object.isEmpty()) { + return new GcsPath(fs, bucket, other); + } else if (object.endsWith("/")) { + return new GcsPath(fs, bucket, object + other); + } else { + return new GcsPath(fs, bucket, object + "/" + other); + } + } + + @Override + public Path resolveSibling(Path other) { + throw new UnsupportedOperationException(); + } + + @Override + public Path resolveSibling(String other) { + throw new UnsupportedOperationException(); + } + + @Override + public Path relativize(Path other) { + throw new UnsupportedOperationException(); + } + + @Override + public GcsPath toAbsolutePath() { + return this; + } + + @Override + public GcsPath toRealPath(LinkOption... options) throws IOException { + return this; + } + + @Override + public File toFile() { + throw new UnsupportedOperationException(); + } + + @Override + public WatchKey register(WatchService watcher, WatchEvent.Kind[] events, + WatchEvent.Modifier... modifiers) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public WatchKey register(WatchService watcher, WatchEvent.Kind... events) + throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() { + return new NameIterator(fs, !bucket.isEmpty(), bucketAndObject()); + } + + private static class NameIterator implements Iterator { + private final FileSystem fs; + private boolean fullPath; + private String name; + + NameIterator(FileSystem fs, boolean fullPath, String name) { + this.fs = fs; + this.fullPath = fullPath; + this.name = name; + } + + @Override + public boolean hasNext() { + return !Strings.isNullOrEmpty(name); + } + + @Override + public GcsPath next() { + int i = name.indexOf('/'); + String component; + if (i >= 0) { + component = name.substring(0, i); + name = name.substring(i + 1); + } else { + component = name; + name = null; + } + if (fullPath) { + fullPath = false; + return new GcsPath(fs, component, ""); + } else { + // Relative paths have no bucket. + return new GcsPath(fs, "", component); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + + @Override + public int compareTo(Path other) { + if (!(other instanceof GcsPath)) { + throw new ClassCastException(); + } + + GcsPath path = (GcsPath) other; + int b = bucket.compareTo(path.bucket); + if (b != 0) { + return b; + } + + // Compare a component at a time, so that the separator char doesn't + // get compared against component contents. Eg, "a/b" < "a-1/b". + Iterator left = iterator(); + Iterator right = path.iterator(); + + while (left.hasNext() && right.hasNext()) { + String leftStr = left.next().toString(); + String rightStr = right.next().toString(); + int c = leftStr.compareTo(rightStr); + if (c != 0) { + return c; + } + } + + if (!left.hasNext() && !right.hasNext()) { + return 0; + } else { + return left.hasNext() ? 1 : -1; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + GcsPath paths = (GcsPath) o; + return bucket.equals(paths.bucket) && object.equals(paths.object); + } + + @Override + public int hashCode() { + int result = bucket.hashCode(); + result = 31 * result + object.hashCode(); + return result; + } + + @Override + public String toString() { + if (!isAbsolute()) { + return object; + } + StringBuilder sb = new StringBuilder(); + sb.append(SCHEME) + .append("://"); + if (!bucket.isEmpty()) { + sb.append(bucket) + .append('/'); + } + sb.append(object); + return sb.toString(); + } + + // TODO: Consider using resource names for all GCS paths used by the SDK. + public String toResourceName() { + StringBuilder sb = new StringBuilder(); + sb.append("storage.googleapis.com/"); + if (!bucket.isEmpty()) { + sb.append(bucket).append('/'); + } + sb.append(object); + return sb.toString(); + } + + @Override + public URI toUri() { + try { + return new URI(SCHEME, "//" + bucketAndObject(), null); + } catch (URISyntaxException e) { + throw new RuntimeException("Unable to create URI for GCS path " + this); + } + } + + private String bucketAndObject() { + if (bucket.isEmpty()) { + return object; + } else { + return bucket + "/" + object; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/package-info.java new file mode 100644 index 000000000000..6784109e82af --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities used to interact with Google Cloud Storage. **/ +package com.google.cloud.dataflow.sdk.util.gcsfs; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/ClientRequestHelper.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/ClientRequestHelper.java new file mode 100644 index 000000000000..155dd79f795b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/ClientRequestHelper.java @@ -0,0 +1,40 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import com.google.api.client.googleapis.services.AbstractGoogleClientRequest; +import com.google.api.client.http.HttpHeaders; + +/** + * ClientRequestHelper provides wrapper methods around final methods of AbstractGoogleClientRequest + * to allow overriding them if necessary. Typically should be used for testing purposes only. + */ +public class ClientRequestHelper { + /** + * Wraps AbstractGoogleClientRequest.getRequestHeaders(). + */ + public HttpHeaders getRequestHeaders(AbstractGoogleClientRequest clientRequest) { + return clientRequest.getRequestHeaders(); + } + + /** + * Wraps AbstractGoogleClientRequest.getMediaHttpUploader(). + */ + public void setChunkSize(AbstractGoogleClientRequest clientRequest, int chunkSize) { + clientRequest.getMediaHttpUploader().setChunkSize(chunkSize); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageExceptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageExceptions.java new file mode 100644 index 000000000000..5535a90826a9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageExceptions.java @@ -0,0 +1,82 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import com.google.api.client.util.Preconditions; +import com.google.common.base.Strings; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.List; + +/** + * Miscellaneous helper methods for standardizing the types of exceptions thrown by the various + * GCS-based FileSystems. + */ +public class GoogleCloudStorageExceptions { + /** + * Creates FileNotFoundException with suitable message for a GCS bucket or object. + */ + public static FileNotFoundException getFileNotFoundException( + String bucketName, String objectName) { + Preconditions.checkArgument(!Strings.isNullOrEmpty(bucketName), + "bucketName must not be null or empty"); + if (objectName == null) { + objectName = ""; + } + return new FileNotFoundException( + String.format("Item not found: %s/%s", bucketName, objectName)); + } + + /** + * Creates a composite IOException out of multiple IOExceptions. If there is only a single + * {@code innerException}, it will be returned as-is without wrapping into an outer exception. + * it. + */ + public static IOException createCompositeException( + List innerExceptions) { + Preconditions.checkArgument(innerExceptions != null, + "innerExceptions must not be null"); + Preconditions.checkArgument(innerExceptions.size() > 0, + "innerExceptions must contain at least one element"); + + if (innerExceptions.size() == 1) { + return innerExceptions.get(0); + } + + IOException combined = new IOException("Multiple IOExceptions."); + for (IOException inner : innerExceptions) { + combined.addSuppressed(inner); + } + return combined; + } + + /** + * Wraps the given IOException into another IOException, adding the given error message and a + * reference to the supplied bucket and object. It allows one to know which bucket and object + * were being accessed when the exception occurred for an operation. + */ + public static IOException wrapException(IOException e, String message, + String bucketName, String objectName) { + String name = "bucket: " + bucketName; + if (!Strings.isNullOrEmpty(objectName)) { + name += ", object: " + objectName; + } + String fullMessage = String.format("%s: %s", message, name); + return new IOException(fullMessage, e); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageReadChannel.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageReadChannel.java new file mode 100644 index 000000000000..a3d9b65347b2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageReadChannel.java @@ -0,0 +1,538 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import com.google.api.client.http.HttpResponse; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Preconditions; +import com.google.api.client.util.Sleeper; +import com.google.api.services.storage.Storage; +import com.google.cloud.dataflow.sdk.util.ApiErrorExtractor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.util.regex.Pattern; + +/** + * Provides seekable read access to GCS. + */ +public class GoogleCloudStorageReadChannel implements SeekableByteChannel { + // Logger. + private static final Logger LOG = LoggerFactory.getLogger(GoogleCloudStorageReadChannel.class); + + // Used to separate elements of a Content-Range + private static final Pattern SLASH = Pattern.compile("/"); + + // GCS access instance. + private Storage gcs; + + // Name of the bucket containing the object being read. + private String bucketName; + + // Name of the object being read. + private String objectName; + + // Read channel. + private ReadableByteChannel readChannel; + + // True if this channel is open, false otherwise. + private boolean channelIsOpen; + + // Current read position in the channel. + private long currentPosition = -1; + + // When a caller calls position(long) to set stream position, we record the target position + // and defer the actual seek operation until the caller tries to read from the channel. + // This allows us to avoid an unnecessary seek to position 0 that would take place on creation + // of this instance in cases where caller intends to start reading at some other offset. + // If lazySeekPending is set to true, it indicates that a target position has been set + // but the actual seek operation is still pending. + private boolean lazySeekPending; + + // Size of the object being read. + private long size = -1; + + // Maximum number of automatic retries when reading from the underlying channel without making + // progress; each time at least one byte is successfully read, the counter of attempted retries + // is reset. + // TODO: Wire this setting out to GHFS; it should correspond to adding the wiring for + // setting the equivalent value inside HttpRequest.java which determines the low-level retries + // during "execute()" calls. The default in HttpRequest.java is also 10. + private int maxRetries = 10; + + // Helper delegate for turning IOExceptions from API calls into higher-level semantics. + private final ApiErrorExtractor errorExtractor; + + // Sleeper used for waiting between retries. + private Sleeper sleeper = Sleeper.DEFAULT; + + // The clock used by ExponentialBackOff to determine when the maximum total elapsed time has + // passed doing a series of retries. + private NanoClock clock = NanoClock.SYSTEM; + + // Lazily initialized BackOff for sleeping between retries; only ever initialized if a retry is + // necessary. + private BackOff backOff = null; + + // Settings used for instantiating the default BackOff used for determining wait time between + // retries. TODO: Wire these out to be settable by the Hadoop configs. + // The number of milliseconds to wait before the very first retry in a series of retries. + public static final int DEFAULT_BACKOFF_INITIAL_INTERVAL_MILLIS = 200; + + // The amount of jitter introduced when computing the next retry sleep interval so that when + // many clients are retrying, they don't all retry at the same time. + public static final double DEFAULT_BACKOFF_RANDOMIZATION_FACTOR = 0.5; + + // The base of the exponent used for exponential backoff; each subsequent sleep interval is + // roughly this many times the previous interval. + public static final double DEFAULT_BACKOFF_MULTIPLIER = 1.5; + + // The maximum amount of sleep between retries; at this point, there will be no further + // exponential backoff. This prevents intervals from growing unreasonably large. + public static final int DEFAULT_BACKOFF_MAX_INTERVAL_MILLIS = 10 * 1000; + + // The maximum total time elapsed since the first retry over the course of a series of retries. + // This makes it easier to bound the maximum time it takes to respond to a permanent failure + // without having to calculate the summation of a series of exponentiated intervals while + // accounting for the randomization of backoff intervals. + public static final int DEFAULT_BACKOFF_MAX_ELAPSED_TIME_MILLIS = 2 * 60 * 1000; + + // ClientRequestHelper to be used instead of calling final methods in client requests. + private static ClientRequestHelper clientRequestHelper = new ClientRequestHelper(); + + /** + * Constructs an instance of GoogleCloudStorageReadChannel. + * + * @param gcs storage object instance + * @param bucketName name of the bucket containing the object to read + * @param objectName name of the object to read + * @throws java.io.FileNotFoundException if the given object does not exist + * @throws IOException on IO error + */ + public GoogleCloudStorageReadChannel( + Storage gcs, String bucketName, String objectName, ApiErrorExtractor errorExtractor) + throws IOException { + this.gcs = gcs; + this.bucketName = bucketName; + this.objectName = objectName; + this.errorExtractor = errorExtractor; + channelIsOpen = true; + position(0); + } + + /** + * Constructs an instance of GoogleCloudStorageReadChannel. + * Used for unit testing only. Do not use elsewhere. + * + * @throws IOException on IO error + */ + GoogleCloudStorageReadChannel() + throws IOException { + this.errorExtractor = null; + channelIsOpen = true; + position(0); + } + + /** + * Sets the ClientRequestHelper to be used instead of calling final methods in client requests. + */ + static void setClientRequestHelper(ClientRequestHelper helper) { + clientRequestHelper = helper; + } + + /** + * Sets the Sleeper used for sleeping between retries. + */ + void setSleeper(Sleeper sleeper) { + Preconditions.checkArgument(sleeper != null, "sleeper must not be null!"); + this.sleeper = sleeper; + } + + /** + * Sets the clock to be used for determining when max total time has elapsed doing retries. + */ + void setNanoClock(NanoClock clock) { + Preconditions.checkArgument(clock != null, "clock must not be null!"); + this.clock = clock; + } + + /** + * Sets the backoff for determining sleep duration between retries. + * + * @param backOff May be null to force the next usage to auto-initialize with default settings. + */ + void setBackOff(BackOff backOff) { + this.backOff = backOff; + } + + /** + * Gets the backoff used for determining sleep duration between retries. May be null if it was + * never lazily initialized. + */ + BackOff getBackOff() { + return backOff; + } + + /** + * Helper for initializing the BackOff used for retries. + */ + private BackOff createBackOff() { + return new ExponentialBackOff.Builder() + .setInitialIntervalMillis(DEFAULT_BACKOFF_INITIAL_INTERVAL_MILLIS) + .setRandomizationFactor(DEFAULT_BACKOFF_RANDOMIZATION_FACTOR) + .setMultiplier(DEFAULT_BACKOFF_MULTIPLIER) + .setMaxIntervalMillis(DEFAULT_BACKOFF_MAX_INTERVAL_MILLIS) + .setMaxElapsedTimeMillis(DEFAULT_BACKOFF_MAX_ELAPSED_TIME_MILLIS) + .setNanoClock(clock) + .build(); + } + + /** + * Sets the number of times to automatically retry by re-opening the underlying readChannel + * whenever an exception occurs while reading from it. The count of attempted retries is reset + * whenever at least one byte is successfully read, so this number of retries refers to retries + * made without achieving any forward progress. + */ + public void setMaxRetries(int maxRetries) { + this.maxRetries = maxRetries; + } + + /** + * Reads from this channel and stores read data in the given buffer. + * + * @param buffer buffer to read data into + * @return number of bytes read or -1 on end-of-stream + * @throws java.io.IOException on IO error + */ + @Override + public int read(ByteBuffer buffer) + throws IOException { + throwIfNotOpen(); + + // Don't try to read if the buffer has no space. + if (buffer.remaining() == 0) { + return 0; + } + + // Perform a lazy seek if not done already. + performLazySeek(); + + int totalBytesRead = 0; + int retriesAttempted = 0; + + // We read from a streaming source. We may not get all the bytes we asked for + // in the first read. Therefore, loop till we either read the required number of + // bytes or we reach end-of-stream. + do { + int remainingBeforeRead = buffer.remaining(); + try { + int numBytesRead = readChannel.read(buffer); + Preconditions.checkState(numBytesRead != 0, "Read 0 bytes without blocking!"); + if (numBytesRead < 0) { + break; + } + totalBytesRead += numBytesRead; + currentPosition += numBytesRead; + + // The count of retriesAttempted is per low-level readChannel.read call; each time we make + // progress we reset the retry counter. + retriesAttempted = 0; + } catch (IOException ioe) { + // TODO: Refactor any reusable logic for retries into a separate RetryHelper class. + if (retriesAttempted == maxRetries) { + LOG.warn("Already attempted max of {} retries while reading '{}'; throwing exception.", + maxRetries, StorageResourceId.createReadableString(bucketName, objectName)); + throw ioe; + } else { + if (retriesAttempted == 0) { + // If this is the first of a series of retries, we also want to reset the backOff + // to have fresh initial values. + if (backOff == null) { + backOff = createBackOff(); + } else { + backOff.reset(); + } + } + + ++retriesAttempted; + LOG.warn("Got exception while reading '{}'; retry # {}. Sleeping...", + StorageResourceId.createReadableString(bucketName, objectName), + retriesAttempted, ioe); + + try { + boolean backOffSuccessful = BackOffUtils.next(sleeper, backOff); + if (!backOffSuccessful) { + LOG.warn("BackOff returned false; maximum total elapsed time exhausted. Giving up " + + "after {} retries for '{}'", retriesAttempted, + StorageResourceId.createReadableString(bucketName, objectName)); + throw ioe; + } + } catch (InterruptedException ie) { + LOG.warn("Interrupted while sleeping before retry." + + "Giving up after {} retries for '{}'", retriesAttempted, + StorageResourceId.createReadableString(bucketName, objectName)); + ioe.addSuppressed(ie); + throw ioe; + } + LOG.info("Done sleeping before retry for '{}'; retry # {}.", + StorageResourceId.createReadableString(bucketName, objectName), + retriesAttempted); + + if (buffer.remaining() != remainingBeforeRead) { + int partialRead = remainingBeforeRead - buffer.remaining(); + LOG.info("Despite exception, had partial read of {} bytes; resetting retry count.", + partialRead); + retriesAttempted = 0; + totalBytesRead += partialRead; + currentPosition += partialRead; + } + + // Force the stream to be reopened by seeking to the current position. + long newPosition = currentPosition; + currentPosition = -1; + position(newPosition); + performLazySeek(); + } + } + } while (buffer.remaining() > 0); + + // If this method was called when the stream was already at EOF + // (indicated by totalBytesRead == 0) then return EOF else, + // return the number of bytes read. + return (totalBytesRead == 0) ? -1 : totalBytesRead; + } + + @Override + public int write(ByteBuffer src) throws IOException { + throw new UnsupportedOperationException("Cannot mutate read-only channel"); + } + + /** + * Tells whether this channel is open. + * + * @return a value indicating whether this channel is open + */ + @Override + public boolean isOpen() { + return channelIsOpen; + } + + /** + * Closes this channel. + * + * @throws IOException on IO error + */ + @Override + public void close() + throws IOException { + throwIfNotOpen(); + channelIsOpen = false; + if (readChannel != null) { + readChannel.close(); + } + } + + /** + * Returns this channel's current position. + * + * @return this channel's current position + */ + @Override + public long position() + throws IOException { + throwIfNotOpen(); + return currentPosition; + } + + /** + * Sets this channel's position. + * + * @param newPosition the new position, counting the number of bytes from the beginning. + * @return this channel instance + * @throws java.io.FileNotFoundException if the underlying object does not exist. + * @throws IOException on IO error + */ + @Override + public SeekableByteChannel position(long newPosition) + throws IOException { + throwIfNotOpen(); + + // If the position has not changed, avoid the expensive operation. + if (newPosition == currentPosition) { + return this; + } + + validatePosition(newPosition); + currentPosition = newPosition; + lazySeekPending = true; + return this; + } + + /** + * Returns size of the object to which this channel is connected. + * + * @return size of the object to which this channel is connected + * @throws IOException on IO error + */ + @Override + public long size() + throws IOException { + throwIfNotOpen(); + // Perform a lazy seek if not done already so that size of this channel is set correctly. + performLazySeek(); + return size; + } + + @Override + public SeekableByteChannel truncate(long size) throws IOException { + throw new UnsupportedOperationException("Cannot mutate read-only channel"); + } + + /** + * Sets size of this channel to the given value. + */ + protected void setSize(long size) { + this.size = size; + } + + /** + * Validates that the given position is valid for this channel. + */ + protected void validatePosition(long newPosition) { + // Validate: 0 <= newPosition + if (newPosition < 0) { + throw new IllegalArgumentException( + String.format("Invalid seek offset: position value (%d) must be >= 0", newPosition)); + } + + // Validate: newPosition < size + // Note that we access this.size directly rather than calling size() to avoid initiating + // lazy seek that leads to recursive error. We validate newPosition < size only when size of + // this channel has been computed by a prior call. This means that position could be + // potentially set to an invalid value (>= size) by position(long). However, that error + // gets caught during lazy seek. + if ((size >= 0) && (newPosition >= size)) { + throw new IllegalArgumentException( + String.format( + "Invalid seek offset: position value (%d) must be between 0 and %d", + newPosition, size)); + } + } + + /** + * Seeks to the given position in the underlying stream. + * + * Note: Seek is an expensive operation because a new stream is opened each time. + * + * @throws java.io.FileNotFoundException if the underlying object does not exist. + * @throws IOException on IO error + */ + private void performLazySeek() + throws IOException { + + // Return quickly if there is no pending seek operation. + if (!lazySeekPending) { + return; + } + + // Close the underlying channel if it is open. + if (readChannel != null) { + readChannel.close(); + } + + InputStream objectContentStream = openStreamAndSetSize(currentPosition); + readChannel = Channels.newChannel(objectContentStream); + lazySeekPending = false; + } + + /** + * Opens the underlying stream, sets its position to the given value and sets size based on + * stream content size. + * + * @param newPosition position to seek into the new stream. + * @throws IOException on IO error + */ + protected InputStream openStreamAndSetSize(long newPosition) + throws IOException { + validatePosition(newPosition); + Storage.Objects.Get getObject = gcs.objects().get(bucketName, objectName); + // Set the range on the existing request headers which may have been initialized with things + // like user-agent already. + clientRequestHelper.getRequestHeaders(getObject) + .setRange(String.format("bytes=%d-", newPosition)); + HttpResponse response; + try { + response = getObject.executeMedia(); + } catch (IOException e) { + if (errorExtractor.itemNotFound(e)) { + throw GoogleCloudStorageExceptions + .getFileNotFoundException(bucketName, objectName); + } else if (errorExtractor.rangeNotSatisfiable(e) + && newPosition == 0 + && size == -1) { + // We don't know the size yet (size == -1) and we're seeking to byte 0, but got 'range + // not satisfiable'; the object must be empty. + LOG.info("Got 'range not satisfiable' for reading {} at position 0; assuming empty.", + StorageResourceId.createReadableString(bucketName, objectName)); + size = 0; + return new ByteArrayInputStream(new byte[0]); + } else { + String msg = String.format("Error reading %s at position %d", + StorageResourceId.createReadableString(bucketName, objectName), newPosition); + throw new IOException(msg, e); + } + } + + String contentRange = response.getHeaders().getContentRange(); + if (response.getHeaders().getContentLength() != null) { + size = response.getHeaders().getContentLength() + newPosition; + } else if (contentRange != null) { + String sizeStr = SLASH.split(contentRange)[1]; + try { + size = Long.parseLong(sizeStr); + } catch (NumberFormatException e) { + throw new IOException( + "Could not determine size from response from Content-Range: " + contentRange, e); + } + } else { + throw new IOException("Could not determine size of response"); + } + return response.getContent(); + } + + /** + * Throws if this channel is not currently open. + */ + private void throwIfNotOpen() + throws IOException { + if (!isOpen()) { + throw new ClosedChannelException(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageWriteChannel.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageWriteChannel.java new file mode 100644 index 000000000000..11113d0367ea --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/GoogleCloudStorageWriteChannel.java @@ -0,0 +1,379 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.InputStreamContent; +import com.google.api.client.util.Preconditions; +import com.google.api.services.storage.Storage; +import com.google.api.services.storage.model.StorageObject; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.WritableByteChannel; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; + +/** + * Implements WritableByteChannel to provide write access to GCS. + */ +public class GoogleCloudStorageWriteChannel + implements WritableByteChannel { + + // The minimum logging interval for upload progress. + private static final long MIN_LOGGING_INTERVAL_MS = 60000L; + + // Logger. + private static final Logger LOG = LoggerFactory.getLogger(GoogleCloudStorageWriteChannel.class); + + // Buffering used in the upload path: + // There are a series of buffers used along the upload path. It is important to understand their + // function before tweaking their values. + // + // Note: Most values are already tweaked based on performance measurements. If you want to change + // buffer sizes, you should change only 1 buffer size at a time to make sure you understand + // the correlation between various buffers and their characteristics. + // + // Upload path: + // Uploading a file involves the following steps: + // -- caller creates a write stream. It involves creating a pipe between data writer (controlled + // by the caller) and data uploader. + // The writer and the uploader are on separate threads. That is, pipe operation is asynchronous + // between its + // two ends. + // -- caller puts data in a ByteBuffer and calls write(ByteBuffer). The write() method starts + // writing into sink end of the pipe. It blocks if pipe buffer is full till the other end + // reads data to make space. + // -- MediaHttpUploader code keeps on reading from the source end of the pipe till it has + // uploadBufferSize amount of data. + // + // The following buffers are involved along the above path: + // -- ByteBuffer passed by caller. We have no control over its size. + // + // -- Pipe buffer. + // size = UPLOAD_PIPE_BUFFER_SIZE_DEFAULT (1 MB) + // Increasing size does not have noticeable difference on performance. + // + // -- Buffer used by Java client + // code. + // size = UPLOAD_CHUNK_SIZE_DEFAULT (64 MB) + + // A pipe that connects write channel used by caller to the input stream used by GCS uploader. + // The uploader reads from input stream which blocks till a caller writes some data to the + // write channel (pipeSinkChannel below). The pipe is formed by connecting pipeSink to pipeSource. + private PipedOutputStream pipeSink; + private PipedInputStream pipeSource; + + // Size of buffer used by upload pipe. + private int pipeBufferSize = UPLOAD_PIPE_BUFFER_SIZE_DEFAULT; + + // A channel wrapper over pipeSink. + private WritableByteChannel pipeSinkChannel; + + // Upload operation that takes place on a separate thread. + private UploadOperation uploadOperation; + + // Default GCS upload granularity. + private static final int GCS_UPLOAD_GRANULARITY = 8 * 1024 * 1024; + + // Upper limit on object size. + // We use less than 250GB limit to avoid potential boundary errors + // in scotty/blobstore stack. + private static final long UPLOAD_MAX_SIZE = 249 * 1024 * 1024 * 1024L; + + // Chunk size to use. Limit the amount of memory used in low memory + // environments such as small AppEngine instances. + private static final int UPLOAD_CHUNK_SIZE_DEFAULT = + Runtime.getRuntime().totalMemory() < 512 * 1024 * 1024 + ? GCS_UPLOAD_GRANULARITY : 8 * GCS_UPLOAD_GRANULARITY; + + // If true, we get very high write throughput but writing files larger than UPLOAD_MAX_SIZE + // will not succeed. Set it to false to allow larger files at lower throughput. + private static boolean limitFileSizeTo250Gb = true; + + // Chunk size to use. + static int uploadBufferSize = UPLOAD_CHUNK_SIZE_DEFAULT; + + // Default size of upload buffer. + public static final int UPLOAD_PIPE_BUFFER_SIZE_DEFAULT = 1 * 1024 * 1024; + + // ClientRequestHelper to be used instead of calling final methods in client requests. + private static ClientRequestHelper clientRequestHelper = new ClientRequestHelper(); + + /** + * Allows running upload operation on a background thread. + */ + static class UploadOperation + implements Runnable { + + // Object to be uploaded. This object declared final for safe object publishing. + private final Storage.Objects.Insert insertObject; + + // Exception encountered during upload. + Throwable exception; + + // Allows other threads to wait for this operation to be complete. This object declared final + // for safe object publishing. + final CountDownLatch uploadDone = new CountDownLatch(1); + + // Read end of the pipe. This object declared final for safe object publishing. + private final InputStream pipeSource; + + /** + * Constructs an instance of UploadOperation. + * + * @param insertObject object to be uploaded + */ + public UploadOperation(Storage.Objects.Insert insertObject, InputStream pipeSource) { + this.insertObject = insertObject; + this.pipeSource = pipeSource; + } + + /** + * Gets exception/error encountered during upload or null. + */ + public Throwable exception() { + return exception; + } + + /** + * Runs the upload operation. + */ + @Override + public void run() { + try { + insertObject.execute(); + } catch (Throwable t) { + exception = t; + LOG.error("Upload failure", t); + } finally { + uploadDone.countDown(); + try { + // Close this end of the pipe so that the writer at the other end + // will not hang indefinitely. + pipeSource.close(); + } catch (IOException ioe) { + LOG.error("Error trying to close pipe.source()", ioe); + // Log and ignore IOException while trying to close the channel, + // as there is not much we can do about it. + } + } + } + + public void waitForCompletion() { + do { + try { + uploadDone.await(); + } catch (InterruptedException e) { + // Ignore it and continue to wait. + } + } while(uploadDone.getCount() > 0); + } + } + + /** + * Constructs an instance of GoogleCloudStorageWriteChannel. + * + * @param threadPool thread pool to use for running the upload operation + * @param gcs storage object instance + * @param bucketName name of the bucket to create object in + * @param objectName name of the object to create + * @throws IOException on IO error + */ + public GoogleCloudStorageWriteChannel( + ExecutorService threadPool, Storage gcs, String bucketName, + String objectName, String contentType) + throws IOException { + init(threadPool, gcs, bucketName, objectName, contentType); + } + + /** + * Sets the ClientRequestHelper to be used instead of calling final methods in client requests. + */ + static void setClientRequestHelper(ClientRequestHelper helper) { + clientRequestHelper = helper; + } + + /** + * Writes contents of the given buffer to this channel. + * + * Note: The data that one writes gets written to a pipe which may not block + * if the pipe has sufficient buffer space. A success code returned from this method + * does not mean that the specific data was successfully written to the underlying + * storage. It simply means that there is no error at present. The data upload + * may encounter an error on a separate thread. Such error is not ignored; + * it shows up as an exception during a subsequent call to write() or close(). + * The only way to be sure of successful upload is when the close() method + * returns successfully. + * + * @param buffer buffer to write + * @throws IOException on IO error + */ + @Override + public int write(ByteBuffer buffer) + throws IOException { + throwIfNotOpen(); + + // No point in writing further if upload failed on another thread. + throwIfUploadFailed(); + + return pipeSinkChannel.write(buffer); + } + + /** + * Tells whether this channel is open. + * + * @return a value indicating whether this channel is open + */ + @Override + public boolean isOpen() { + return (pipeSinkChannel != null) && pipeSinkChannel.isOpen(); + } + + /** + * Closes this channel. + * + * Note: + * The method returns only after all data has been successfully written to GCS + * or if there is a non-retry-able error. + * + * @throws IOException on IO error + */ + @Override + public void close() + throws IOException { + throwIfNotOpen(); + try { + pipeSinkChannel.close(); + uploadOperation.waitForCompletion(); + throwIfUploadFailed(); + } finally { + pipeSinkChannel = null; + pipeSink = null; + pipeSource = null; + uploadOperation = null; + } + } + + /** + * Sets size of upload buffer used. + */ + public static void setUploadBufferSize(int bufferSize) { + Preconditions.checkArgument(bufferSize > 0, + "Upload buffer size must be great than 0."); + if (bufferSize % GCS_UPLOAD_GRANULARITY != 0) { + LOG.warn("Upload buffer size should be a multiple of {} for best performance, got {}", + GCS_UPLOAD_GRANULARITY, bufferSize); + } + GoogleCloudStorageWriteChannel.uploadBufferSize = bufferSize; + } + + /** + * Enables or disables hard limit of 250GB on size of uploaded files. + * + * If enabled, we get very high write throughput but writing files larger than UPLOAD_MAX_SIZE + * will not succeed. Set it to false to allow larger files at lower throughput. + */ + public static void enableFileSizeLimit250Gb(boolean enableLimit) { + GoogleCloudStorageWriteChannel.limitFileSizeTo250Gb = enableLimit; + } + + /** + * Initializes an instance of GoogleCloudStorageWriteChannel. + * + * @param threadPool thread pool to use for running the upload operation + * @param gcs storage object instance + * @param bucketName name of the bucket in which to create object + * @param objectName name of the object to create + * @throws IOException on IO error + */ + private void init( + ExecutorService threadPool, Storage gcs, String bucketName, + String objectName, String contentType) + throws IOException { + + // Create object with the given name. + StorageObject object = (new StorageObject()).setName(objectName); + + // Create a pipe such that its one end is connected to the input stream used by + // the uploader and the other end is the write channel used by the caller. + pipeSource = new PipedInputStream(pipeBufferSize); + pipeSink = new PipedOutputStream(pipeSource); + pipeSinkChannel = Channels.newChannel(pipeSink); + + // Connect pipe-source to the stream used by uploader. + InputStreamContent objectContentStream = + new InputStreamContent(contentType, pipeSource); + // Indicate that we do not know length of file in advance. + objectContentStream.setLength(-1); + objectContentStream.setCloseInputStream(false); + Storage.Objects.Insert insertObject = + gcs.objects().insert(bucketName, object, objectContentStream); + insertObject.setDisableGZipContent(true); + insertObject.getMediaHttpUploader().setProgressListener( + new LoggingMediaHttpUploaderProgressListener(objectName, MIN_LOGGING_INTERVAL_MS)); + + // Insert necessary http headers to enable 250GB limit+high throughput if so configured. + if (limitFileSizeTo250Gb) { + HttpHeaders headers = clientRequestHelper.getRequestHeaders(insertObject); + headers.set("X-Goog-Upload-Desired-Chunk-Granularity", GCS_UPLOAD_GRANULARITY); + headers.set("X-Goog-Upload-Max-Raw-Size", UPLOAD_MAX_SIZE); + } + // Change chunk size from default value (10MB) to one that yields higher performance. + clientRequestHelper.setChunkSize(insertObject, uploadBufferSize); + + // Given that the two ends of the pipe must operate asynchronous relative + // to each other, we need to start the upload operation on a separate thread. + uploadOperation = new UploadOperation(insertObject, pipeSource); + threadPool.execute(uploadOperation); + } + + /** + * Throws if this channel is not currently open. + * + * @throws IOException on IO error + */ + private void throwIfNotOpen() + throws IOException { + if (!isOpen()) { + throw new ClosedChannelException(); + } + } + + /** + * Throws if upload operation failed. Propagates any errors. + * + * @throws IOException on IO error + */ + private void throwIfUploadFailed() + throws IOException { + if ((uploadOperation != null) && (uploadOperation.exception() != null)) { + if (uploadOperation.exception() instanceof Error) { + throw (Error) uploadOperation.exception(); + } + throw new IOException(uploadOperation.exception()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/LoggingMediaHttpUploaderProgressListener.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/LoggingMediaHttpUploaderProgressListener.java new file mode 100644 index 000000000000..c215f4aeafaf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/LoggingMediaHttpUploaderProgressListener.java @@ -0,0 +1,91 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import com.google.api.client.googleapis.media.MediaHttpUploader; +import com.google.api.client.googleapis.media.MediaHttpUploader.UploadState; +import com.google.api.client.googleapis.media.MediaHttpUploaderProgressListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Logs the status of uploads. At the beginning, during, and + * at the end of the upload, emits relevant statistics such as how many bytes + * uploaded and the rate at which the upload is progressing. + *

+ * A new instance of this progress listener should be used for each MediaHttpUploader. + */ +class LoggingMediaHttpUploaderProgressListener implements MediaHttpUploaderProgressListener { + private static final Logger LOG = + LoggerFactory.getLogger(MediaHttpUploaderProgressListener.class); + private static final double BYTES_IN_MB = 1048576.0; + private final long minLoggingInterval; + private final String name; + private long startTime; + private long prevTime; + private long prevUploadedBytes; + + /** + * Creates a upload progress listener which emits relevant statistics about the + * progress of the upload. + * @param name The name of the resource being uploaded. + * @param minLoggingInterval The minimum amount of time (millis) between logging upload progress. + */ + LoggingMediaHttpUploaderProgressListener(String name, long minLoggingInterval) { + this.name = name; + this.minLoggingInterval = minLoggingInterval; + } + + @Override + public void progressChanged(MediaHttpUploader uploader) throws IOException { + progressChanged(LOG, + uploader.getUploadState(), + uploader.getNumBytesUploaded(), + System.currentTimeMillis()); + } + + void progressChanged(Logger log, UploadState uploadState, long bytesUploaded, long currentTime) { + switch (uploadState) { + case INITIATION_STARTED: + startTime = currentTime; + prevTime = currentTime; + log.info("Uploading: {}", name); + break; + case MEDIA_IN_PROGRESS: + // Limit messages to be emitted for in progress uploads. + if (currentTime > prevTime + minLoggingInterval) { + double averageRate = (bytesUploaded / BYTES_IN_MB) + / ((currentTime - startTime) / 1000.0); + double currentRate = ((bytesUploaded - prevUploadedBytes) / BYTES_IN_MB) + / ((currentTime - prevTime) / 1000.0); + log.info(String.format( + "Uploading: %s Average Rate: %.3f MiB/s, Current Rate: %.3f MiB/s, Total: %.3f MiB", + name, averageRate, currentRate, bytesUploaded / BYTES_IN_MB)); + prevTime = currentTime; + prevUploadedBytes = bytesUploaded; + } + break; + case MEDIA_COMPLETE: + log.info("Finished Uploading: {}", name); + break; + default: + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/StorageResourceId.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/StorageResourceId.java new file mode 100644 index 000000000000..b6051a5147d3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsio/StorageResourceId.java @@ -0,0 +1,165 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import com.google.api.client.util.Preconditions; +import com.google.common.base.Strings; + +import java.util.Objects; + +/** + * Data struct representing either a GCS StorageObject, a GCS Bucket or the GCS root (gs://). + * If both bucketName and objectName are null, the StorageResourceId refers to GCS root (gs://). + * If bucketName is non-null, and objectName is null, then this refers to a GCS Bucket. Otherwise, + * if bucketName and objectName are both non-null, this refers to a GCS StorageObject. + */ +public class StorageResourceId { + // The singleton instance identifying the GCS root (gs://). Both getObjectName() and + // getBucketName() will return null. + public static final StorageResourceId ROOT = new StorageResourceId(); + + // Bucket name of this storage resource to be used with the Google Cloud Storage API. + private final String bucketName; + + // Object name of this storage resource to be used with the Google Cloud Storage API. + private final String objectName; + + // Human-readable String to be returned by toString(); kept as 'final' member for efficiency. + private final String readableString; + + /** + * Constructor for a StorageResourceId which refers to the GCS root (gs://). Private because + * all external users should just use the singleton StorageResourceId.ROOT. + */ + private StorageResourceId() { + this.bucketName = null; + this.objectName = null; + this.readableString = createReadableString(bucketName, objectName); + } + + /** + * Constructor for a StorageResourceId representing a Bucket; {@code getObjectName()} will return + * null for a StorageResourceId which represents a Bucket. + * + * @param bucketName The bucket name of the resource. Must be non-empty and non-null. + */ + public StorageResourceId(String bucketName) { + Preconditions.checkArgument(!Strings.isNullOrEmpty(bucketName), + "bucketName must not be null or empty"); + + this.bucketName = bucketName; + this.objectName = null; + this.readableString = createReadableString(bucketName, objectName); + } + + /** + * Constructor for a StorageResourceId representing a full StorageObject, including bucketName + * and objectName. + * + * @param bucketName The bucket name of the resource. Must be non-empty and non-null. + * @param objectName The object name of the resource. Must be non-empty and non-null. + */ + public StorageResourceId(String bucketName, String objectName) { + Preconditions.checkArgument(!Strings.isNullOrEmpty(bucketName), + "bucketName must not be null or empty"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(objectName), + "objectName must not be null or empty"); + + this.bucketName = bucketName; + this.objectName = objectName; + this.readableString = createReadableString(bucketName, objectName); + } + + /** + * Returns true if this StorageResourceId represents a GCS StorageObject; if true, both + * {@code getBucketName} and {@code getObjectName} will be non-empty and non-null. + */ + public boolean isStorageObject() { + return bucketName != null && objectName != null; + } + + /** + * Returns true if this StorageResourceId represents a GCS Bucket; if true, then {@code + * getObjectName} will return null. + */ + public boolean isBucket() { + return bucketName != null && objectName == null; + } + + /** + * Returns true if this StorageResourceId represents the GCS root (gs://); if true, then + * both {@code getBucketName} and {@code getObjectName} will be null. + */ + public boolean isRoot() { + return bucketName == null && objectName == null; + } + + /** + * Gets the bucket name component of this resource identifier. + */ + public String getBucketName() { + return bucketName; + } + + /** + * Gets the object name component of this resource identifier. + */ + public String getObjectName() { + return objectName; + } + + /** + * Returns a string of the form gs:///. + */ + @Override + public String toString() { + return readableString; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof StorageResourceId) { + StorageResourceId other = (StorageResourceId) obj; + return Objects.equals(bucketName, other.bucketName) + && Objects.equals(objectName, other.objectName); + } + return false; + } + + @Override + public int hashCode() { + return readableString.hashCode(); + } + + /** + * Helper for standardizing the way various human-readable messages in logs/exceptions which refer + * to a bucket/object pair. + */ + public static String createReadableString(String bucketName, String objectName) { + if (bucketName == null && objectName == null) { + // TODO: Unify this method with other methods which convert bucketName/objectName + // to a URI; maybe use the single slash for compatibility. + return "gs://"; + } else if (bucketName != null && objectName == null) { + return String.format("gs://%s", bucketName); + } else if (bucketName != null && objectName != null) { + return String.format("gs://%s/%s", bucketName, objectName); + } + throw new IllegalArgumentException( + String.format("Invalid bucketName/objectName pair: gs://%s/%s", bucketName, objectName)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/package-info.java new file mode 100644 index 000000000000..98fdc44113a3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities used by the Dataflow SDK. **/ +package com.google.cloud.dataflow.sdk.util; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/CodedTupleTag.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/CodedTupleTag.java new file mode 100644 index 000000000000..3caed1a8bcce --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/CodedTupleTag.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +/** + * A {@link TupleTag} combined with the {@link Coder} to use for + * values associated with the tag. + * + *

Used as tags in + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.KeyedState}. + * + * @param the type of the values associated with this tag + */ +public class CodedTupleTag extends TupleTag { + /** + * Returns a {@code CodedTupleTag} with the given id which uses the + * given {@code Coder} whenever a value associated with the tag + * needs to be serialized. + * + *

It is up to the user to ensure that two + * {@code CodedTupleTag}s with the same id actually mean the same + * tag and carry the same generic type parameter. Violating this + * invariant can lead to hard-to-diagnose runtime type errors. + * + *

(An explicit id is required so that persistent keyed state + * saved by one run of a streaming program can be reused if that + * streaming program is upgraded to a new version.) + * + * @param the type of the values associated with the tag + */ + public static CodedTupleTag of(String id, Coder coder) { + return new CodedTupleTag(id, coder); + } + + /** + * Returns the {@code Coder} used for values associated with this tag. + */ + public Coder getCoder() { + return coder; + } + + + /////////////////////////////////////////////// + + private final Coder coder; + + CodedTupleTag(String id, Coder coder) { + super(id); + this.coder = coder; + } + + @Override + public String toString() { + return "CodedTupleTag<" + getId() + ", " + coder + ">"; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/CodedTupleTagMap.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/CodedTupleTagMap.java new file mode 100644 index 000000000000..6f96c694ea2e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/CodedTupleTagMap.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import java.util.Map; + +/** + * A mapping of {@link CodedTupleTag}s to associated values. + * + *

Returned by + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.KeyedState#lookup(java.util.List)}. + */ +public class CodedTupleTagMap { + /** + * Returns a {@code CodedTupleTagMap} containing the given mappings. + * + *

It is up to the caller to ensure that the value associated + * with each CodedTupleTag in the map has the static type specified + * by that tag. + * + *

Intended for internal use only. + */ + public static CodedTupleTagMap of(Map, Object> map) { + // TODO: Should we copy the Map here, to insulate this + // map from any changes to the original argument? + return new CodedTupleTagMap(map); + } + + /** + * Returns the value associated with the given tag in this + * {@code CodedTupleTagMap}, or {@code null} if the tag has no + * asssociated value. + */ + public T get(CodedTupleTag tag) { + return (T) map.get(tag); + } + + ////////////////////////////////////////////// + + private Map, Object> map; + + CodedTupleTagMap(Map, Object> map) { + this.map = map; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/KV.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/KV.java new file mode 100644 index 000000000000..d354707ebb0c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/KV.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * An immutable key/value pair. + * + *

Various + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s like + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} and + * {@link com.google.cloud.dataflow.sdk.transforms.Combine#perKey} + * work on {@link PCollection}s of KVs. + * + * @param the type of the key + * @param the type of the value + */ +public class KV implements Serializable { + /** Returns a KV with the given key and value. */ + public static KV of(K key, V value) { + return new KV<>(key, value); + } + + /** Returns the key of this KV. */ + public K getKey() { + return key; + } + + /** Returns the value of this KV. */ + public V getValue() { + return value; + } + + + ///////////////////////////////////////////////////////////////////////////// + + final K key; + final V value; + + private KV(K key, V value) { + this.key = key; + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof KV) { + KV that = (KV) o; + return (this.key == null ? that.key == null + : this.key.equals(that.key)) + && (this.value == null ? that.value == null + : this.value.equals(that.value)); + } + return false; + } + + /** Orders the KV by the key. A null key is less than any non-null key. */ + public static class OrderByKey, V> implements + Comparator>, Serializable { + @Override + public int compare(KV a, KV b) { + if (a.key == null) { + return b.key == null ? 0 : -1; + } else if (b.key == null) { + return 1; + } else { + return a.key.compareTo(b.key); + } + } + } + + /** Orders the KV by the value. A null value is less than any non-null value. */ + public static class OrderByValue> + implements Comparator>, Serializable { + @Override + public int compare(KV a, KV b) { + if (a.value == null) { + return b.value == null ? 0 : -1; + } else if (b.value == null) { + return 1; + } else { + return a.value.compareTo(b.value); + } + } + } + + @Override + public int hashCode() { + return getClass().hashCode() + + (key == null ? 0 : key.hashCode()) + + (value == null ? 0 : value.hashCode()); + } + + @Override + public String toString() { + return "KV(" + key + ", " + value + ")"; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PBegin.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PBegin.java new file mode 100644 index 000000000000..fc3f179fc176 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PBegin.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; +import java.util.Collections; + +/** + * {@code PBegin} is used as the "input" to a root + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform} which + * is the first operation in a {@link Pipeline}, such as + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read} or + * {@link com.google.cloud.dataflow.sdk.transforms.Create}. + * + *

Typically created by calling {@link Pipeline#begin} on a Pipeline. + */ +public class PBegin implements PInput { + /** + * Returns a {@code PBegin} in the given {@code Pipeline}. + */ + public static PBegin in(Pipeline pipeline) { + return new PBegin(pipeline); + } + + /** + * Applies the given PTransform to this input PBegin, and + * returns the PTransform's Output. + */ + public Output apply( + PTransform t) { + return Pipeline.applyTransform(this, t); + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Collection expand() { + // A PBegin contains no PValues. + return Collections.emptyList(); + } + + @Override + public void finishSpecifying() { + // Nothing more to be done. + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Constructs a {@code PBegin} in the given {@code Pipeline}. + */ + protected PBegin(Pipeline pipeline) { + this.pipeline = pipeline; + } + + private Pipeline pipeline; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollection.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollection.java new file mode 100644 index 000000000000..fc4b0886b7d5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollection.java @@ -0,0 +1,240 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.common.reflect.TypeToken; + +/** + * A {@code PCollection} is an immutable collection of values of type + * {@code T}. A {@code PCollection} can contain either a bounded or unbounded + * number of elements. Bounded and unbounded {@code PCollection}s are produced + * as the output of {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s + * (including root PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read}, + * {@link com.google.cloud.dataflow.sdk.io.PubsubIO.Read} and + * {@link com.google.cloud.dataflow.sdk.transforms.Create}), and can + * be passed as the inputs of other PTransforms. + * + *

Some root transforms produce bounded {@code PCollections} and others + * produce unbounded ones. For example, + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read} reads a static set + * of files, so it produces a bounded {@code PCollection}. + * {@link com.google.cloud.dataflow.sdk.io.PubsubIO.Read}, on the other hand, + * receives a potentially infinite stream of Pubsub messages, so it produces + * an unbounded {@code PCollection}. + * + *

Each element in a {@code PCollection} may have an associated implicit + * timestamp. Sources assign timestamps to elements when they create + * {@code PCollection}s, and other {@code PTransform}s propagate these + * timestamps from their input to their output. For example, PubsubIO.Read + * assigns pubsub message timestamps to elements, and TextIO.Read assigns + * the default value {@code Long.MIN_VALUE} to elements. User code can + * explicitly assign timestamps to elements with + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#outputWithTimestamp}. + * + *

Additionally, a {@code PCollection} has an associated + * {@link WindowingFn} and each element is assigned to a set of windows. + * By default, the windowing function is + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow} + * and all elements are assigned into a single default window. + * This default can be overridden with the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * {@code PTransform}. Dataflow pipelines run in classic batch MapReduce style + * with the default GlobalWindow strategy if timestamps are ignored. + * + *

See the individual {@code PTransform} subclasses for specific information + * on how they propagate timestamps and windowing. + * + * @param the type of the elements of this PCollection + */ +public class PCollection extends TypedPValue { + /** + * Returns the name of this PCollection. + * + *

By default, the name of a PCollection is based on the name of the + * PTransform that produces it. It can be specified explicitly by + * calling {@link #setName}. + * + * @throws IllegalStateException if the name hasn't been set yet + */ + @Override + public String getName() { + return super.getName(); + } + + /** + * Sets the name of this PCollection. Returns {@code this}. + * + * @throws IllegalStateException if this PCollection has already been + * finalized and is no longer settable, e.g., by having + * {@code apply()} called on it + */ + @Override + public PCollection setName(String name) { + super.setName(name); + return this; + } + + /** + * Returns the Coder used by this PCollection to encode and decode + * the values stored in it. + * + * @throws IllegalStateException if the Coder hasn't been set, and + * couldn't be inferred + */ + @Override + public Coder getCoder() { + return super.getCoder(); + } + + /** + * Sets the Coder used by this PCollection to encode and decode the + * values stored in it. Returns {@code this}. + * + * @throws IllegalStateException if this PCollection has already + * been finalized and is no longer settable, e.g., by having + * {@code apply()} called on it + */ + @Override + public PCollection setCoder(Coder coder) { + super.setCoder(coder); + return this; + } + + /** + * Returns whether or not the elements of this PCollection have a + * well-defined and fixed order, such that subsequent reading of the + * PCollection is guaranteed to process the elements in order. + * + *

Requiring a fixed order can limit optimization opportunities. + * + *

By default, PCollections do not have a well-defined or fixed order. + */ + public boolean isOrdered() { + return isOrdered; + } + + /** + * Sets whether or not this PCollection should preserve the order in + * which elements are put in it, such that subsequent parallel + * reading of the PCollection is guaranteed to process the elements + * in order. + * + *

Requiring a fixed order can limit optimization opportunities. + * + *

Returns {@code this}. + * + * @throws IllegalStateException if this PCollection has already + * been finalized and is no longer settable, e.g., by having + * {@code apply()} called on it + */ + public PCollection setOrdered(boolean isOrdered) { + if (this.isOrdered != isOrdered) { + if (isFinishedSpecifyingInternal()) { + throw new IllegalStateException( + "cannot change the orderedness of " + this + + " once it's been used"); + } + this.isOrdered = isOrdered; + } + return this; + } + + /** + * Applies the given PTransform to this input PCollection, and + * returns the PTransform's Output. + */ + public Output apply( + PTransform, Output> t) { + return Pipeline.applyTransform(this, t); + } + + /** + * Returns the {@link WindowingFn} of this {@code PCollection}. + */ + public WindowingFn getWindowingFn() { + return windowingFn; + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + /** + * Whether or not the elements of this PCollection have a + * well-defined and fixed order, such that subsequent reading of the + * PCollection is guaranteed to process the elements in order. + */ + private boolean isOrdered = false; + + /** + * {@link WindowingFn} that will be used to merge windows in + * this {@code PCollection} and subsequent {@code PCollection}s produced + * from this one. + * + *

By default, no merging is performed. + */ + private WindowingFn windowingFn; + + private PCollection() {} + + /** + * Sets the {@code TypeToken} for this {@code PCollection}, so that + * the enclosing {@code PCollectionTuple}, {@code PCollectionList}, + * or {@code PTransform>}, etc., can provide + * more detailed reflective information. + */ + @Override + public PCollection setTypeTokenInternal(TypeToken typeToken) { + super.setTypeTokenInternal(typeToken); + return this; + } + + /** + * Sets the {@link WindowingFn} of this {@code PCollection}. + * + *

For use by primitive transformations only. + */ + public PCollection setWindowingFnInternal(WindowingFn windowingFn) { + this.windowingFn = windowingFn; + return this; + } + + /** + * Sets the {@link Pipeline} for this {@code PCollection}. + * + *

For use by primitive transformations only. + */ + @Override + public PCollection setPipelineInternal(Pipeline pipeline) { + super.setPipelineInternal(pipeline); + return this; + } + + /** + * Creates and returns a new PCollection for a primitive output. + * + *

For use by primitive transformations only. + */ + public static PCollection createPrimitiveOutputInternal( + WindowingFn windowingFn) { + return new PCollection().setWindowingFnInternal(windowingFn); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionList.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionList.java new file mode 100644 index 000000000000..26b7300a9341 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionList.java @@ -0,0 +1,227 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +/** + * A {@code PCollectionList} is an immutable list of homogeneously + * typed {@code PCollection}s. A PCollectionList is used, for + * instance, as the input to + * {@link com.google.cloud.dataflow.sdk.transforms.Flatten} or the + * output of + * {@link com.google.cloud.dataflow.sdk.transforms.Partition}. + * + *

PCollectionLists can be created and accessed like follows: + *

 {@code
+ * PCollection pc1 = ...;
+ * PCollection pc2 = ...;
+ * PCollection pc3 = ...;
+ *
+ * // Create a PCollectionList with three PCollections:
+ * PCollectionList pcs = PCollectionList.of(pc1).and(pc2).and(pc3);
+ *
+ * // Create an empty PCollectionList:
+ * Pipeline p = ...;
+ * PCollectionList pcs2 = PCollectionList.empty(p);
+ *
+ * // Get PCollections out of a PCollectionList, by index (origin 0):
+ * PCollection pcX = pcs.get(1);
+ * PCollection pcY = pcs.get(0);
+ * PCollection pcZ = pcs.get(2);
+ *
+ * // Get a list of all PCollections in a PCollectionList:
+ * List> allPcs = pcs.getAll();
+ * } 
+ * + * @param the type of the elements of all the PCollections in this list + */ +public class PCollectionList implements PInput, POutput { + /** + * Returns an empty PCollectionList that is part of the given Pipeline. + * + *

Longer PCollectionLists can be created by calling + * {@link #and} on the result. + */ + public static PCollectionList empty(Pipeline pipeline) { + return new PCollectionList<>(pipeline); + } + + /** + * Returns a singleton PCollectionList containing the given PCollection. + * + *

Longer PCollectionLists can be created by calling + * {@link #and} on the result. + */ + public static PCollectionList of(PCollection pc) { + return new PCollectionList(pc.getPipeline()).and(pc); + } + + /** + * Returns a PCollectionList containing the given PCollections, in order. + * + *

The argument list cannot be empty. + * + *

All the PCollections in the resulting PCollectionList must be + * part of the same Pipeline. + * + *

Longer PCollectionLists can be created by calling + * {@link #and} on the result. + */ + public static PCollectionList of(Iterable> pcs) { + Iterator> pcsIter = pcs.iterator(); + if (!pcsIter.hasNext()) { + throw new IllegalArgumentException( + "must either have a non-empty list of PCollections, " + + "or must first call empty(Pipeline)"); + } + return new PCollectionList(pcsIter.next().getPipeline()).and(pcs); + } + + /** + * Returns a new PCollectionList that has all the PCollections of + * this PCollectionList plus the given PCollection appended to the end. + * + *

All the PCollections in the resulting PCollectionList must be + * part of the same Pipeline. + */ + public PCollectionList and(PCollection pc) { + if (pc.getPipeline() != pipeline) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + return new PCollectionList<>(pipeline, + new ImmutableList.Builder>() + .addAll(pcollections) + .add(pc) + .build()); + } + + /** + * Returns a new PCollectionList that has all the PCollections of + * this PCollectionList plus the given PCollections appended to the end, + * in order. + * + *

All the PCollections in the resulting PCollectionList must be + * part of the same Pipeline. + */ + public PCollectionList and(Iterable> pcs) { + List> copy = new ArrayList<>(pcollections); + for (PCollection pc : pcs) { + if (pc.getPipeline() != pipeline) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + copy.add(pc); + } + return new PCollectionList<>(pipeline, copy); + } + + /** + * Returns the number of PCollections in this PCollectionList. + */ + public int size() { + return pcollections.size(); + } + + /** + * Returns the PCollection at the given index (origin zero). Throws + * IndexOutOfBounds if the index is out of the range + * {@code [0..size()-1]}. + */ + public PCollection get(int index) { + return pcollections.get(index); + } + + /** + * Returns an immutable List of all the PCollections in this PCollectionList. + */ + public List> getAll() { + return pcollections; + } + + /** + * Applies the given PTransform to this input {@code PCollectionList}, + * and returns the PTransform's Output. + */ + public Output apply( + PTransform, Output> t) { + return Pipeline.applyTransform(this, t); + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + final Pipeline pipeline; + final List> pcollections; + + PCollectionList(Pipeline pipeline) { + this(pipeline, new ArrayList>()); + } + + PCollectionList(Pipeline pipeline, List> pcollections) { + this.pipeline = pipeline; + this.pcollections = Collections.unmodifiableList(pcollections); + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Collection expand() { + return pcollections; + } + + @Override + public void recordAsOutput(Pipeline pipeline, + PTransform transform) { + if (this.pipeline != null && this.pipeline != pipeline) { + throw new AssertionError( + "not expecting to change the Pipeline owning a PCollectionList"); + } + int i = 0; + for (PCollection pc : pcollections) { + pc.recordAsOutput(pipeline, transform, "out" + i); + i++; + } + } + + @Override + public void finishSpecifying() { + for (PCollection pc : pcollections) { + pc.finishSpecifying(); + } + } + + @Override + public void finishSpecifyingOutput() { + for (PCollection pc : pcollections) { + pc.finishSpecifyingOutput(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionTuple.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionTuple.java new file mode 100644 index 000000000000..fecc175f4d3c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionTuple.java @@ -0,0 +1,252 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.common.collect.ImmutableMap; +import com.google.common.reflect.TypeToken; + +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A {@code PCollectionTuple} is an immutable tuple of + * heterogeneously-typed {@link PCollection}s, "keyed" by + * {@link TupleTag}s. A PCollectionTuple can be used as the input or + * output of a + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform} taking + * or producing multiple PCollection inputs or outputs that can be of + * different types, for instance a + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} with side + * outputs. + * + *

PCollectionTuples can be created and accessed like follows: + *

 {@code
+ * PCollection pc1 = ...;
+ * PCollection pc2 = ...;
+ * PCollection> pc3 = ...;
+ *
+ * // Create TupleTags for each of the PCollections to put in the
+ * // PCollectionTuple (the type of the TupleTag enables tracking the
+ * // static type of each of the PCollections in the PCollectionTuple):
+ * TupleTag tag1 = new TupleTag<>();
+ * TupleTag tag2 = new TupleTag<>();
+ * TupleTag> tag3 = new TupleTag<>();
+ *
+ * // Create a PCollectionTuple with three PCollections:
+ * PCollectionTuple pcs =
+ *     PCollectionTuple.of(tag1, pc1)
+ *                     .and(tag2, pc2)
+ *                     .and(tag3, pc3);
+ *
+ * // Create an empty PCollectionTuple:
+ * Pipeline p = ...;
+ * PCollectionTuple pcs2 = PCollectionTuple.empty(p);
+ *
+ * // Get PCollections out of a PCollectionTuple, using the same tags
+ * // that were used to put them in:
+ * PCollection pcX = pcs.get(tag2);
+ * PCollection pcY = pcs.get(tag1);
+ * PCollection> pcZ = pcs.get(tag3);
+ *
+ * // Get a map of all PCollections in a PCollectionTuple:
+ * Map, PCollection> allPcs = pcs.getAll();
+ * } 
+ */ +public class PCollectionTuple implements PInput, POutput { + /** + * Returns an empty PCollectionTuple that is part of the given Pipeline. + * + *

Longer PCollectionTuples can be created by calling + * {@link #and} on the result. + */ + public static PCollectionTuple empty(Pipeline pipeline) { + return new PCollectionTuple(pipeline); + } + + /** + * Returns a singleton PCollectionTuple containing the given + * PCollection keyed by the given TupleTag. + * + *

Longer PCollectionTuples can be created by calling + * {@link #and} on the result. + */ + public static PCollectionTuple of(TupleTag tag, PCollection pc) { + return empty(pc.getPipeline()).and(tag, pc); + } + + /** + * Returns a new PCollectionTuple that has all the PCollections and + * tags of this PCollectionTuple plus the given PCollection and tag. + * + *

The given TupleTag should not already be mapped to a + * PCollection in this PCollectionTuple. + * + *

All the PCollections in the resulting PCollectionTuple must be + * part of the same Pipeline. + */ + public PCollectionTuple and(TupleTag tag, PCollection pc) { + if (pc.getPipeline() != pipeline) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + + // The TypeToken in tag will often have good + // reflective information about T + pc.setTypeTokenInternal(tag.getTypeToken()); + return new PCollectionTuple(pipeline, + new ImmutableMap.Builder, PCollection>() + .putAll(pcollectionMap) + .put(tag, pc) + .build()); + } + + /** + * Returns whether this PCollectionTuple contains a PCollection with + * the given tag. + */ + public boolean has(TupleTag tag) { + return pcollectionMap.containsKey(tag); + } + + /** + * Returns the PCollection with the given tag in this + * PCollectionTuple. Throws IllegalArgumentException if there is no + * such PCollection, i.e., {@code !has(tag)}. + */ + public PCollection get(TupleTag tag) { + @SuppressWarnings("unchecked") + PCollection pcollection = (PCollection) pcollectionMap.get(tag); + if (pcollection == null) { + throw new IllegalArgumentException( + "TupleTag not found in this PCollectionTuple tuple"); + } + return pcollection; + } + + /** + * Returns an immutable Map from TupleTag to corresponding + * PCollection, for all the members of this PCollectionTuple. + */ + public Map, PCollection> getAll() { + return pcollectionMap; + } + + /** + * Applies the given PTransform to this input PCollectionTuple, and + * returns the PTransform's Output. + */ + public Output apply( + PTransform t) { + return Pipeline.applyTransform(this, t); + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + Pipeline pipeline; + final Map, PCollection> pcollectionMap; + + PCollectionTuple(Pipeline pipeline) { + this(pipeline, new LinkedHashMap, PCollection>()); + } + + PCollectionTuple(Pipeline pipeline, + Map, PCollection> pcollectionMap) { + this.pipeline = pipeline; + this.pcollectionMap = Collections.unmodifiableMap(pcollectionMap); + } + + /** + * Returns a PCollectionTuple with each of the given tags mapping to a new + * output PCollection. + * + *

For use by primitive transformations only. + */ + public static PCollectionTuple ofPrimitiveOutputsInternal( + TupleTagList outputTags, WindowingFn windowingFn) { + Map, PCollection> pcollectionMap = new LinkedHashMap<>(); + for (TupleTag outputTag : outputTags.tupleTags) { + if (pcollectionMap.containsKey(outputTag)) { + throw new IllegalArgumentException( + "TupleTag already present in this tuple"); + } + + // In fact, `token` and `outputCollection` should have + // types TypeToken and PCollection for some + // unknown T. It is safe to create `outputCollection` + // with type PCollection because it has the same + // erasure as the correct type. When a transform adds + // elements to `outputCollection` they will be of type T. + @SuppressWarnings("unchecked") + TypeToken token = (TypeToken) outputTag.getTypeToken(); + PCollection outputCollection = PCollection + .createPrimitiveOutputInternal(windowingFn) + .setTypeTokenInternal(token); + + pcollectionMap.put(outputTag, outputCollection); + } + return new PCollectionTuple(null, pcollectionMap); + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Collection expand() { + return pcollectionMap.values(); + } + + @Override + public void recordAsOutput(Pipeline pipeline, + PTransform transform) { + if (this.pipeline != null && this.pipeline != pipeline) { + throw new AssertionError( + "not expecting to change the Pipeline owning a PCollectionTuple"); + } + this.pipeline = pipeline; + int i = 0; + for (Map.Entry, PCollection> entry + : pcollectionMap.entrySet()) { + TupleTag tag = entry.getKey(); + PCollection pc = entry.getValue(); + pc.recordAsOutput(pipeline, transform, tag.getOutName(i)); + i++; + } + } + + @Override + public void finishSpecifying() { + for (PCollection pc : pcollectionMap.values()) { + pc.finishSpecifying(); + } + } + + @Override + public void finishSpecifyingOutput() { + for (PCollection pc : pcollectionMap.values()) { + pc.finishSpecifyingOutput(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionView.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionView.java new file mode 100644 index 000000000000..d19854ccc588 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionView.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +import java.io.Serializable; + +/** + * A {@code PCollectionView} is an immutable view of a + * {@link PCollection} that can be accessed e.g. as a + * side input to a {@link DoFn}. + * + *

A {@PCollectionView} should always be the output of a {@link PTransform}. It is + * the joint responsibility of this transform and each {@link PipelineRunner} to + * implement the view in a runner-specific manner. + * + * @param the type of the value(s) accessible via this {@code PCollectionView} + * @param the type of the windowed value(s) accessible via this {@code PCollectionView} + */ +public interface PCollectionView extends PValue, Serializable { + /** + * A unique identifier, for internal use. + */ + public TupleTag>> getTagInternal(); + + /** + * For internal use only. + */ + public T fromIterableInternal(Iterable> contents); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PDone.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PDone.java new file mode 100644 index 000000000000..dda48fc530a8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PDone.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import java.util.Collection; +import java.util.Collections; + +/** + * {@code PDone} is the output of a + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform} that + * doesn't have a non-trival result, e.g., a Write. No more + * transforms can be applied to it. + */ +public class PDone extends POutputValueBase { + public PDone() {} + + @Override + public Collection expand() { + // A PDone contains no PValues. + return Collections.emptyList(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PInput.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PInput.java new file mode 100644 index 000000000000..6d86fb069535 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PInput.java @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; + +import java.util.Collection; + +/** + * The abstract interface of things that might be input to a + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform}. + */ +public interface PInput { + /** + * Returns the owning Pipeline of this PInput. + * + * @throws IllegalStateException if the owning Pipeline hasn't been + * set yet + */ + public Pipeline getPipeline(); + + /** + * Expands this PInput into a list of its component input PValues. + * + *

A PValue expands to itself. + * + *

A tuple or list of PValues (e.g., + * PCollectionTuple, and PCollectionList) expands to its component + * PValues. + * + *

Not intended to be invoked directly by user code. + */ + public Collection expand(); + + /** + *

After building, finalizes this PInput to make it ready for + * being used as an input to a PTransform. + * + *

Automatically invoked whenever {@code apply()} is invoked on + * this PInput, so users do not normally call this explicitly. + */ + public void finishSpecifying(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutput.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutput.java new file mode 100644 index 000000000000..3b3264985d55 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutput.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; + +/** + * The abstract interface of things that might be output from a + * {@link PTransform}. + */ +public interface POutput { + /** + * Expands this {@code POutput} into a list of its component output + * {@code PValue}s. + * + *

A {@link PValue} expands to itself. + * + *

A tuple or list of {@code PValue}s (e.g., + * {@link PCollectionTuple}, and + * {@link PCollectionList}) expands to its component {@code PValue}s. + * + *

Not intended to be invoked directly by user code. + */ + public Collection expand(); + + /** + * Records that this {@code POutput} is an output of the given + * {@code PTransform} in the given {@code Pipeline}. + * + *

Should expand this {@code POutput} and invoke + * {@link PValue#recordAsOutput(Pipeline, + * com.google.cloud.dataflow.sdk.transforms.PTransform, + * String)} on each component output {@code PValue}. + * + *

Automatically invoked as part of applying a + * {@code PTransform}. Not to be invoked directly by user code. + */ + public void recordAsOutput(Pipeline pipeline, + PTransform transform); + + /** + * As part of finishing the producing {@code PTransform}, finalizes this + * {@code PTransform} output to make it ready for being used as an input and + * for running. + * + *

This includes ensuring that all {@code PCollection}s + * have {@code Coder}s specified or defaulted. + * + *

Automatically invoked whenever this {@code POutput} is used + * as a {@code PInput} to another {@code PTransform}, or if never + * used as a {@code PInput}, when {@link Pipeline#run} is called, so + * users do not normally call this explicitly. + */ + public void finishSpecifyingOutput(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutputValueBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutputValueBase.java new file mode 100644 index 000000000000..0401393f142b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutputValueBase.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * A {@code POutputValueBase} is the abstract base class of + * {@code PTransform} outputs. + * + *

A {@code PValueBase} that adds tracking of its producing + * {@code PTransform}. + * + *

For internal use. + */ +public abstract class POutputValueBase implements POutput { + + protected POutputValueBase() { } + + /** + * Returns the {@code PTransform} that this {@code POutputValueBase} + * is an output of. + * + *

For internal use only. + */ + public PTransform getProducingTransformInternal() { + return producingTransform; + } + + /** + * Records that this {@code POutputValueBase} is an output with the + * given name of the given {@code PTransform} in the given + * {@code Pipeline}. + * + *

To be invoked only by {@link POutput#recordAsOutput} + * implementations. Not to be invoked directly by user code. + */ + public void recordAsOutput(Pipeline pipeline, + PTransform transform) { + if (producingTransform != null) { + // Already used this POutput as a PTransform output. This can + // happen if the POutput is an output of a transform within a + // composite transform, and is also the result of the composite. + // We want to record the "immediate" atomic transform producing + // this output, and ignore all later composite transforms that + // also produce this output. + // + // Pipeline.applyInternal() uses !hasProducingTransform() to + // avoid calling this operation redundantly, but + // hasProducingTransform() doesn't apply to POutputValueBases + // that aren't PValues or composites of PValues, e.g., PDone. + return; + } + producingTransform = transform; + } + + /** + * Default behavior for {@code finishSpecifyingOutput()} is + * to do nothing. Override if your {@link PValue} requires + * finalization. + */ + public void finishSpecifyingOutput() { } + + /** + * The {@code PTransform} that produces this {@code POutputValueBase}. + */ + private PTransform producingTransform; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValue.java new file mode 100644 index 000000000000..7e45196af813 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValue.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * A {@code PValue} is the interface to values that can be + * input and output from {@link PTransform}s. + */ +public interface PValue extends POutput, PInput { + public String getName(); + + public PValue setPipelineInternal(Pipeline pipeline); + + /** + * Returns the {@code PTransform} that this {@code PValue} is an output of. + * + *

For internal use only. + */ + public PTransform getProducingTransformInternal(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValueBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValueBase.java new file mode 100644 index 000000000000..25b1fd6fd9a1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValueBase.java @@ -0,0 +1,190 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.StringUtils; + +import java.util.Collection; +import java.util.Collections; + +/** + * A {@code PValueBase} is an abstract base class that provides + * sensible default implementations for methods of {@link PValue}. + * In particular, this includes functionality for getting/setting: + * + *

    + *
  • The {@code Pipeline} that the {@code PValue} is + * part of. + *
  • Whether the {@code PValue} has bee finalized (as an input + * or an output), after which its properties can + * no longer be changed. + *
+ * + *

For internal use. + */ +public abstract class PValueBase extends POutputValueBase implements PValue { + /** + * Returns the name of this {@code PValueBase}. + * + *

By default, the name of a {@code PValueBase} is based on the + * name of the {@code PTransform} that produces it. It can be + * specified explicitly by calling {@link #setName}. + * + * @throws IllegalStateException if the name hasn't been set yet + */ + public String getName() { + if (name == null) { + throw new IllegalStateException("name not set"); + } + return name; + } + + /** + * Sets the name of this {@code PValueBase}. Returns {@code this}. + * + * @throws IllegalStateException if this {@code PValueBase} has + * already been finalized and is no longer settable, e.g., by having + * {@code apply()} called on it + */ + public PValueBase setName(String name) { + if (finishedSpecifying) { + throw new IllegalStateException( + "cannot change the name of " + this + " once it's been used"); + } + this.name = name; + return this; + } + + ///////////////////////////////////////////////////////////////////////////// + + protected PValueBase() {} + + /** + * The name of this {@code PValueBase}, or null if not yet set. + */ + private String name; + + /** + * The {@code Pipeline} that owns this {@code PValueBase}, or null + * if not yet set. + */ + private Pipeline pipeline; + + /** + * Whether this {@code PValueBase} has been finalized, and its core + * properties, e.g., name, can no longer be changed. + */ + private boolean finishedSpecifying = false; + + + /** + * Returns the owning {@code Pipeline} of this {@code PValueBase}. + * + * @throws IllegalStateException if the owning {@code Pipeline} + * hasn't been set yet + */ + @Override + public Pipeline getPipeline() { + if (pipeline == null) { + throw new IllegalStateException("owning pipeline not set"); + } + return pipeline; + } + + /** + * Sets the owning {@code Pipeline} of this {@code PValueBase}. + * Returns {@code this}. + * + *

For internal use only. + * + * @throws IllegalArgumentException if the owner has already been set + * differently + */ + @Override + public PValue setPipelineInternal(Pipeline pipeline) { + if (this.pipeline != null + && this.pipeline != pipeline) { + throw new IllegalArgumentException( + "owning pipeline cannot be changed once set"); + } + this.pipeline = pipeline; + return this; + } + + @Override + public void recordAsOutput(Pipeline pipeline, + PTransform transform) { + recordAsOutput(pipeline, transform, "out"); + } + + /** + * Records that this {@code POutputValueBase} is an output with the + * given name of the given {@code PTransform} in the given + * {@code Pipeline}. + * + *

To be invoked only by {@link POutput#recordAsOutput} + * implementations. Not to be invoked directly by user code. + */ + protected void recordAsOutput(Pipeline pipeline, + PTransform transform, + String outName) { + super.recordAsOutput(pipeline, transform); + if (name == null) { + name = pipeline.getFullName(transform) + "." + outName; + } + } + + /** + * Returns whether this {@code PValueBase} has been finalized, and + * its core properties, e.g., name, can no longer be changed. + * + *

For internal use only. + */ + public boolean isFinishedSpecifyingInternal() { + return finishedSpecifying; + } + + @Override + public Collection expand() { + return Collections.singletonList(this); + } + + @Override + public void finishSpecifying() { + getProducingTransformInternal().finishSpecifying(); + finishedSpecifying = true; + } + + @Override + public String toString() { + return (name == null ? "" : getName()) + + " [" + getKindString() + "]"; + } + + /** + * Returns a {@code String} capturing the kind of this + * {@code PValueBase}. + * + *

By default, uses the base name of this {@code PValueBase}'s + * class as its kind string. + */ + protected String getKindString() { + return StringUtils.approximateSimpleName(getClass()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TimestampedValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TimestampedValue.java new file mode 100644 index 000000000000..9d91a18cb3cf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TimestampedValue.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; + +/** + * An immutable (value, timestamp) pair. + * + *

Used for assigning initial timestamps to values inserted into a pipeline + * with {@link com.google.cloud.dataflow.sdk.transforms.Create#timestamped}. + * + * @param the type of the value + */ +public class TimestampedValue { + + /** + * Returns a new {@code TimestampedValue} with the given value and timestamp. + */ + public static TimestampedValue of(V value, Instant timestamp) { + return new TimestampedValue<>(value, timestamp); + } + + public V getValue() { + return value; + } + + public Instant getTimestamp() { + return timestamp; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Coder for {@code TimestampedValue}. + */ + public static class TimestampedValueCoder + extends StandardCoder> { + + private final Coder valueCoder; + + public static TimestampedValueCoder of(Coder valueCoder) { + return new TimestampedValueCoder<>(valueCoder); + } + + @JsonCreator + public static TimestampedValueCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + @SuppressWarnings("unchecked") + TimestampedValueCoder(Coder valueCoder) { + this.valueCoder = checkNotNull(valueCoder); + } + + @Override + public void encode(TimestampedValue windowedElem, + OutputStream outStream, + Context context) + throws IOException { + valueCoder.encode(windowedElem.getValue(), outStream, context.nested()); + InstantCoder.of().encode( + windowedElem.getTimestamp(), outStream, context); + } + + @Override + public TimestampedValue decode(InputStream inStream, Context context) + throws IOException { + T value = valueCoder.decode(inStream, context.nested()); + Instant timestamp = InstantCoder.of().decode(inStream, context); + return TimestampedValue.of(value, timestamp); + } + + @Override + public boolean isDeterministic() { + return valueCoder.isDeterministic(); + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(valueCoder); + } + + public static List getInstanceComponents(TimestampedValue exampleValue) { + return Arrays.asList(exampleValue.getValue()); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private final V value; + private final Instant timestamp; + + protected TimestampedValue(V value, Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTag.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTag.java new file mode 100644 index 000000000000..58562163f4a8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTag.java @@ -0,0 +1,170 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.common.reflect.TypeToken; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.Random; + +/** + * A {@code TupleTag} is a typed tag to use as the key of a + * heterogeneously typed tuple, like {@link PCollectionTuple} or + * Its generic type parameter allows tracking + * the static type of things stored in tuples. + * + *

To aid in assigning default {@code Coder}s for results of + * side outputs of {@code ParDo}, an output + * {@code TupleTag} should be instantiated with an extra {@code {}} so + * it is an instance of an anonymous subclass without generic type + * parameters. Input {@code TupleTag}s require no such extra + * instantiation (although it doesn't hurt). For example: + * + *

 {@code
+ * TupleTag inputTag = new TupleTag<>();
+ * TupleTag outputTag = new TupleTag(){};
+ * } 
+ * + * @param the type of the elements or values of the tagged thing, + * e.g., a {@code PCollection}. + */ +public class TupleTag implements Serializable { + /** + * Constructs a new {@code TupleTag}, with a fresh unique id. + * + *

This is the normal way {@code TupleTag}s are constructed. + */ + public TupleTag() { + this.id = genId(); + this.generated = true; + } + + /** + * Constructs a new {@code TupleTag} with the given id. + * + *

It is up to the user to ensure that two {@code TupleTag}s + * with the same id actually mean the same tag and carry the same + * generic type parameter. Violating this invariant can lead to + * hard-to-diagnose runtime type errors. Consequently, this + * operation should be used very sparingly, such as when the + * producer and consumer of {@code TupleTag}s are written in + * separate modules and can only coordinate via ids rather than + * shared {@code TupleTag} instances. Most of the time, + * {@link #TupleTag()} should be preferred. + */ + public TupleTag(String id) { + this.id = id; + this.generated = false; + } + + /** + * Returns the id of this {@code TupleTag}. + * + *

Two {@code TupleTag}s with the same id are considered equal. + * + *

{@code TupleTag}s are not ordered, i.e., the class does not implement + * Comparable interface. TupleTags implement equals and hashCode, making them + * suitable for use as keys in HashMap and HashSet. + */ + public String getId() { return id; } + + /** + * If this {@code TupleTag} is tagging output {@code outputIndex} of + * a {@code PTransform}, returns the name that should be used by + * default for the output. + */ + public String getOutName(int outIndex) { + if (generated) { + return "out" + outIndex; + } else { + return id; + } + } + + /** + * Returns a {@code TypeToken} capturing what is known statically + * about the type of this {@code TupleTag} instance's most-derived + * class. + * + *

This is useful for a {@code TupleTag} constructed as an + * instance of an anonymous subclass with a trailing {@code {}}, + * e.g., {@code new TupleTag(){}}. + */ + public TypeToken getTypeToken() { + return new TypeToken(getClass()) {}; + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + static final Random RANDOM = new Random(0); + + final String id; + final boolean generated; + + /** Generates and returns a fresh unique id for a TupleTag's id. */ + static String genId() { + long randomLong; + synchronized (RANDOM) { + randomLong = RANDOM.nextLong(); + } + return Long.toHexString(randomLong); + } + + @JsonCreator + private static TupleTag fromJson( + @JsonProperty(PropertyNames.VALUE) String id, + @JsonProperty(PropertyNames.IS_GENERATED) boolean generated) { + return new TupleTag(id, generated); + } + + private TupleTag(String id, boolean generated) { + this.id = id; + this.generated = generated; + } + + public CloudObject asCloudObject() { + CloudObject result = CloudObject.forClass(getClass()); + addString(result, PropertyNames.VALUE, id); + addBoolean(result, PropertyNames.IS_GENERATED, generated); + return result; + } + + @Override + public boolean equals(Object that) { + if (that instanceof TupleTag) { + return this.id.equals(((TupleTag) that).id); + } else { + return false; + } + } + + @Override + public int hashCode() { return id.hashCode(); } + + @Override + public String toString() { return "Tag<" + id + ">"; } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTagList.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTagList.java new file mode 100644 index 000000000000..27a0683bab5a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTagList.java @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.common.collect.ImmutableList; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A {@code TupleTagList} is an immutable list of heterogeneously + * typed {@link TupleTag}s. A TupleTagList is used, for instance, to + * specify the tags of the side outputs of a + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo}. + * + *

TupleTagLists can be created and accessed like follows: + *

 {@code
+ * TupleTag tag1 = ...;
+ * TupleTag tag2 = ...;
+ * TupleTag> tag3 = ...;
+ *
+ * // Create a TupleTagList with three TupleTags:
+ * TupleTagList tags = TupleTagList.of(tag1).and(tag2).and(tag3);
+ *
+ * // Create an empty TupleTagList:
+ * Pipeline p = ...;
+ * TupleTagList tags2 = TupleTagList.empty(p);
+ *
+ * // Get TupleTags out of a TupleTagList, by index (origin 0):
+ * TupleTag tagX = tags.get(1);
+ * TupleTag tagY = tags.get(0);
+ * TupleTag tagZ = tags.get(2);
+ *
+ * // Get a list of all TupleTags in a TupleTagList:
+ * List> allTags = tags.getAll();
+ * } 
+ */ +public class TupleTagList implements Serializable { + /** + * Returns an empty TupleTagList. + * + *

Longer TupleTagLists can be created by calling + * {@link #and} on the result. + */ + public static TupleTagList empty() { + return new TupleTagList(); + } + + /** + * Returns a singleton TupleTagList containing the given TupleTag. + * + *

Longer TupleTagLists can be created by calling + * {@link #and} on the result. + */ + public static TupleTagList of(TupleTag tag) { + return empty().and(tag); + } + + /** + * Returns a TupleTagList containing the given TupleTags, in order. + * + *

Longer TupleTagLists can be created by calling + * {@link #and} on the result. + */ + public static TupleTagList of(List> tags) { + return empty().and(tags); + } + + /** + * Returns a new TupleTagList that has all the TupleTags of + * this TupleTagList plus the given TupleTag appended to the end. + */ + public TupleTagList and(TupleTag tag) { + return new TupleTagList( + new ImmutableList.Builder>() + .addAll(tupleTags) + .add(tag) + .build()); + } + + /** + * Returns a new TupleTagList that has all the TupleTags of + * this TupleTagList plus the given TupleTags appended to the end, + * in order. + */ + public TupleTagList and(List> tags) { + return new TupleTagList( + new ImmutableList.Builder>() + .addAll(tupleTags) + .addAll(tags) + .build()); + } + + /** + * Returns the number of TupleTags in this TupleTagList. + */ + public int size() { + return tupleTags.size(); + } + + /** + * Returns the TupleTag at the given index (origin zero). Throws + * IndexOutOfBounds if the index is out of the range + * {@code [0..size()-1]}. + */ + public TupleTag get(int index) { + return tupleTags.get(index); + } + + /** + * Returns an immutable List of all the TupleTags in this TupleTagList. + */ + public List> getAll() { + return tupleTags; + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + final List> tupleTags; + + TupleTagList() { + this(new ArrayList>()); + } + + TupleTagList(List> tupleTags) { + this.tupleTags = Collections.unmodifiableList(tupleTags); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypedPValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypedPValue.java new file mode 100644 index 000000000000..95b9b45f5377 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypedPValue.java @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.common.reflect.TypeToken; + +/** + * A {@code TypedPValue} is the abstract base class of things that + * store some number of values of type {@code T}. Because we know + * the type {@code T}, this is the layer of the inheritance hierarchy where + * we store a coder for objects of type {@code T} + * + * @param the type of the values stored in this {@code TypedPValue} + */ +public abstract class TypedPValue extends PValueBase implements PValue { + + /** + * Returns the Coder used by this TypedPValue to encode and decode + * the values stored in it. + * + * @throws IllegalStateException if the Coder hasn't been set, and + * couldn't be inferred + */ + public Coder getCoder() { + if (coder == null) { + throw new IllegalStateException( + "coder for " + this + " not set, and couldn't be inferred; " + + "either register a default Coder for its element type, " + + "or use setCoder() to specify one explicitly"); + } + return coder; + } + + /** + * Sets the Coder used by this TypedPValue to encode and decode the + * values stored in it. Returns {@code this}. + * + * @throws IllegalStateException if this TypedPValue has already + * been finalized and is no longer settable, e.g., by having + * {@code apply()} called on it + */ + public TypedPValue setCoder(Coder coder) { + if (isFinishedSpecifyingInternal()) { + throw new IllegalStateException( + "cannot change the Coder of " + this + " once it's been used"); + } + if (coder == null) { + throw new IllegalArgumentException( + "Cannot setCoder(null)"); + } + this.coder = coder; + return this; + } + + @Override + public void recordAsOutput(Pipeline pipeline, + PTransform transform, + String outName) { + super.recordAsOutput(pipeline, transform, outName); + pipeline.addValueInternal(this); + } + + @Override + public TypedPValue setPipelineInternal(Pipeline pipeline) { + super.setPipelineInternal(pipeline); + return this; + } + + /** + * After building, finalizes this PValue to make it ready for + * running. Automatically invoked whenever the PValue is "used" + * (e.g., when apply() is called on it) and when the Pipeline is + * run (useful if this is a PValue with no consumers). + */ + @Override + public void finishSpecifying() { + if (isFinishedSpecifyingInternal()) { + return; + } + super.finishSpecifying(); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + /** + * The Coder used by this TypedPValue to encode and decode the + * values stored in it, or null if not specified nor inferred yet. + */ + private Coder coder; + + protected TypedPValue() {} + + private TypeToken typeToken; + + /** + * Returns a {@code TypeToken} with some reflective information + * about {@code T}, if possible. May return {@code null} if no information + * is available. Subclasses may override this to enable better + * {@code Coder} inference. + */ + public TypeToken getTypeToken() { + return typeToken; + } + + /** + * Sets the {@code TypeToken} associated with this class. Better + * reflective type information will lead to better {@code Coder} + * inference. + */ + public TypedPValue setTypeTokenInternal(TypeToken typeToken) { + this.typeToken = typeToken; + return this; + } + + + /** + * If the coder is not explicitly set, this sets the coder for + * this {@code TypedPValue} to the best coder that can be inferred + * based upon the known {@code TypeToken}. By default, this is null, + * but can and should be improved by subclasses. + */ + @Override + public void finishSpecifyingOutput() { + if (coder == null) { + TypeToken token = getTypeToken(); + CoderRegistry registry = getProducingTransformInternal() + .getPipeline() + .getCoderRegistry(); + + if (token != null) { + coder = registry.getDefaultCoder(token); + } + + if (coder == null) { + coder = getProducingTransformInternal().getDefaultOutputCoder(this); + } + + if (coder == null) { + throw new IllegalStateException( + "unable to infer a default Coder for " + this + + "; either register a default Coder for its element type, " + + "or use setCoder() to specify one explicitly. " + + "If a default coder is registered, it may not be found " + + "due to type erasure; again, use setCoder() to specify " + + "a Coder explicitly"); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/package-info.java new file mode 100644 index 000000000000..ba6e927e0996 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/package-info.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.values.PCollection} and other classes for + * representing data in a {@link com.google.cloud.dataflow.sdk.Pipeline}. + * + *

A {@link com.google.cloud.dataflow.sdk.values.PCollection} is an immutable collection of + * values of type {@code T} and is the main representation for data. + * A {@link com.google.cloud.dataflow.sdk.values.PCollectionTuple} is a tuple of PCollections + * used in cases where PTransforms take or return multiple PCollections. + * + *

A {@link com.google.cloud.dataflow.sdk.values.PCollectionTuple} is an immutable tuple of + * heterogeneously-typed {@link com.google.cloud.dataflow.sdk.values.PCollection}s, "keyed" by + * {@link com.google.cloud.dataflow.sdk.values.TupleTag}s. + * A PCollectionTuple can be used as the input or + * output of a + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform} taking + * or producing multiple PCollection inputs or outputs that can be of + * different types, for instance a + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} with side + * outputs. + * + *

A {@link com.google.cloud.dataflow.sdk.values.PCollectionView} is an immutable view of a + * PCollection that can be accessed from a DoFn and other user Fns + * as a side input. + * + */ +package com.google.cloud.dataflow.sdk.values; diff --git a/sdk/src/main/resources/com/google/cloud/dataflow/sdk/sdk.properties b/sdk/src/main/resources/com/google/cloud/dataflow/sdk/sdk.properties new file mode 100644 index 000000000000..5b0a720b215d --- /dev/null +++ b/sdk/src/main/resources/com/google/cloud/dataflow/sdk/sdk.properties @@ -0,0 +1,5 @@ +# SDK source version. +version=${pom.version} + +build.date=${timestamp} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/PipelineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/PipelineTest.java new file mode 100644 index 000000000000..13d2b2996cfd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/PipelineTest.java @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.hamcrest.core.IsNot.not; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.util.UserCodeException; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for Pipeline. + */ +@RunWith(JUnit4.class) +public class PipelineTest { + + static class PipelineWrapper extends Pipeline { + protected PipelineWrapper(PipelineRunner runner) { + super(runner, PipelineOptionsFactory.create()); + } + } + + // Mock class that throws a user code exception during the call to + // Pipeline.run(). + static class TestPipelineRunnerThrowingUserException + extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + Throwable t = new IllegalStateException("user code exception"); + throw new UserCodeException(t); + } + } + + // Mock class that throws an SDK or API client code exception during + // the call to Pipeline.run(). + static class TestPipelineRunnerThrowingSDKException + extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + throw new IllegalStateException("SDK exception"); + } + } + + @Test + public void testPipelineUserExceptionHandling() { + Pipeline p = new PipelineWrapper( + new TestPipelineRunnerThrowingUserException()); + + // Check pipeline runner correctly catches user errors. + try { + Object results = p.run(); + fail("Should have thrown an exception."); + } catch (RuntimeException exn) { + // Make sure users don't have to worry about the + // UserCodeException wrapper. + Assert.assertThat(exn, not(instanceOf(UserCodeException.class))); + // Assert that the message is correct. + Assert.assertThat( + exn.getMessage(), containsString("user code exception")); + // Cause should be IllegalStateException. + Assert.assertThat( + exn.getCause(), instanceOf(IllegalStateException.class)); + } + } + + @Test + public void testPipelineSDKExceptionHandling() { + Pipeline p = new PipelineWrapper(new TestPipelineRunnerThrowingSDKException()); + + // Check pipeline runner correctly catches SDK errors. + try { + Object results = p.run(); + fail("Should have thrown an exception."); + } catch (RuntimeException exn) { + // Make sure the exception isn't a UserCodeException. + Assert.assertThat(exn, not(instanceOf(UserCodeException.class))); + // Assert that the message is correct. + Assert.assertThat(exn.getMessage(), containsString("SDK exception")); + // RuntimeException should be IllegalStateException. + Assert.assertThat(exn, instanceOf(IllegalStateException.class)); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/TestUtils.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/TestUtils.java new file mode 100644 index 000000000000..9a92ba0167c4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/TestUtils.java @@ -0,0 +1,231 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Utilities for tests. + */ +public class TestUtils { + // Do not instantiate. + private TestUtils() {} + + public static final String[] NO_LINES_ARRAY = new String[] { }; + + public static final List NO_LINES = Arrays.asList(NO_LINES_ARRAY); + + public static final String[] LINES_ARRAY = new String[] { + "To be, or not to be: that is the question: ", + "Whether 'tis nobler in the mind to suffer ", + "The slings and arrows of outrageous fortune, ", + "Or to take arms against a sea of troubles, ", + "And by opposing end them? To die: to sleep; ", + "No more; and by a sleep to say we end ", + "The heart-ache and the thousand natural shocks ", + "That flesh is heir to, 'tis a consummation ", + "Devoutly to be wish'd. To die, to sleep; ", + "To sleep: perchance to dream: ay, there's the rub; ", + "For in that sleep of death what dreams may come ", + "When we have shuffled off this mortal coil, ", + "Must give us pause: there's the respect ", + "That makes calamity of so long life; ", + "For who would bear the whips and scorns of time, ", + "The oppressor's wrong, the proud man's contumely, ", + "The pangs of despised love, the law's delay, ", + "The insolence of office and the spurns ", + "That patient merit of the unworthy takes, ", + "When he himself might his quietus make ", + "With a bare bodkin? who would fardels bear, ", + "To grunt and sweat under a weary life, ", + "But that the dread of something after death, ", + "The undiscover'd country from whose bourn ", + "No traveller returns, puzzles the will ", + "And makes us rather bear those ills we have ", + "Than fly to others that we know not of? ", + "Thus conscience does make cowards of us all; ", + "And thus the native hue of resolution ", + "Is sicklied o'er with the pale cast of thought, ", + "And enterprises of great pith and moment ", + "With this regard their currents turn awry, ", + "And lose the name of action.--Soft you now! ", + "The fair Ophelia! Nymph, in thy orisons ", + "Be all my sins remember'd." }; + + public static final List LINES = Arrays.asList(LINES_ARRAY); + + public static final String[] LINES2_ARRAY = new String[] { + "hi", "there", "bob!" }; + + public static final List LINES2 = Arrays.asList(LINES2_ARRAY); + + public static final Integer[] NO_INTS_ARRAY = new Integer[] { }; + + public static final List NO_INTS = Arrays.asList(NO_INTS_ARRAY); + + public static final Integer[] INTS_ARRAY = new Integer[] { + 3, 42, Integer.MAX_VALUE, 0, -1, Integer.MIN_VALUE, 666 }; + + public static final List INTS = Arrays.asList(INTS_ARRAY); + + /** + * Matcher for KVs. + */ + public static class KvMatcher + extends TypeSafeMatcher> { + final Matcher keyMatcher; + final Matcher valueMatcher; + + public static KvMatcher isKv(Matcher keyMatcher, + Matcher valueMatcher) { + return new KvMatcher<>(keyMatcher, valueMatcher); + } + + public KvMatcher(Matcher keyMatcher, + Matcher valueMatcher) { + this.keyMatcher = keyMatcher; + this.valueMatcher = valueMatcher; + } + + @Override + public boolean matchesSafely(KV kv) { + return keyMatcher.matches(kv.getKey()) + && valueMatcher.matches(kv.getValue()); + } + + @Override + public void describeTo(Description description) { + description + .appendText("a KV(").appendValue(keyMatcher) + .appendText(", ").appendValue(valueMatcher) + .appendText(")"); + } + } + + public static PCollection createStrings(Pipeline p, + Iterable values) { + return p.apply(Create.of(values)).setCoder(StringUtf8Coder.of()); + } + + public static PCollection createInts(Pipeline p, + Iterable values) { + return p.apply(Create.of(values)).setCoder(BigEndianIntegerCoder.of()); + } + + public static PCollectionView + createSingletonInt(Pipeline p, Integer value) { + PCollection collection = p.apply(Create.of(value)); + return collection.apply(View.asSingleton()); + } + + //////////////////////////////////////////////////////////////////////////// + // Utilities for testing CombineFns, ensuring they give correct results + // across various permutations and shardings of the input. + + public static void checkCombineFn( + CombineFn fn, List input, final VO expected) { + checkCombineFn(fn, input, CoreMatchers.is(expected)); + } + + public static void checkCombineFn( + CombineFn fn, List input, Matcher matcher) { + checkCombineFnInternal(fn, input, matcher); + Collections.shuffle(input); + checkCombineFnInternal(fn, input, matcher); + } + + private static void checkCombineFnInternal( + CombineFn fn, List input, Matcher matcher) { + int size = input.size(); + checkCombineFnShards(fn, Collections.singletonList(input), matcher); + checkCombineFnShards(fn, shardEvenly(input, 2), matcher); + if (size > 4) { + checkCombineFnShards(fn, shardEvenly(input, size / 2), matcher); + checkCombineFnShards( + fn, shardEvenly(input, (int) (size / Math.sqrt(size))), matcher); + } + checkCombineFnShards(fn, shardExponentially(input, 1.4), matcher); + checkCombineFnShards(fn, shardExponentially(input, 2), matcher); + checkCombineFnShards(fn, shardExponentially(input, Math.E), matcher); + } + + public static void checkCombineFnShards( + CombineFn fn, + List> shards, + Matcher matcher) { + checkCombineFnShardsInternal(fn, shards, matcher); + Collections.shuffle(shards); + checkCombineFnShardsInternal(fn, shards, matcher); + } + + private static void checkCombineFnShardsInternal( + CombineFn fn, + Iterable> shards, + Matcher matcher) { + List accumulators = new ArrayList<>(); + for (Iterable shard : shards) { + VA accumulator = fn.createAccumulator(); + for (VI elem : shard) { + fn.addInput(accumulator, elem); + } + accumulators.add(accumulator); + } + VA merged = fn.mergeAccumulators(accumulators); + assertThat(fn.extractOutput(merged), matcher); + } + + private static List> shardEvenly(List input, int numShards) { + List> shards = new ArrayList<>(numShards); + for (int i = 0; i < numShards; i++) { + shards.add(input.subList(i * input.size() / numShards, + (i + 1) * input.size() / numShards)); + } + return shards; + } + + private static List> shardExponentially( + List input, double base) { + assert base > 1.0; + List> shards = new ArrayList<>(); + int end = input.size(); + while (end > 0) { + int start = (int) (end / base); + shards.add(input.subList(start, end)); + end = start; + } + return shards; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/AvroCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/AvroCoderTest.java new file mode 100644 index 000000000000..725c0e852022 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/AvroCoderTest.java @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; + +/** + * Tests for AvroCoder. + */ +@RunWith(JUnit4.class) +public class AvroCoderTest { + + @DefaultCoder(AvroCoder.class) + private static class Pojo { + public String text; + public int count; + + // Empty constructor required for Avro decoding. + public Pojo() { + } + + public Pojo(String text, int count) { + this.text = text; + this.count = count; + } + + // auto-generated + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Pojo pojo = (Pojo) o; + + if (count != pojo.count) { + return false; + } + if (text != null + ? !text.equals(pojo.text) + : pojo.text != null) { + return false; + } + + return true; + } + + @Override + public String toString() { + return "Pojo{" + + "text='" + text + '\'' + + ", count=" + count + + '}'; + } + } + + static class GetTextFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().text); + } + } + + @Test + public void testAvroCoderEncoding() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + CloudObject encoding = coder.asCloudObject(); + + Assert.assertThat(encoding.keySet(), + Matchers.containsInAnyOrder("@type", "type", "schema")); + } + + @Test + public void testPojoEncoding() throws Exception { + Pojo before = new Pojo("Hello", 42); + + AvroCoder coder = AvroCoder.of(Pojo.class); + byte[] bytes = CoderUtils.encodeToByteArray(coder, before); + Pojo after = CoderUtils.decodeFromByteArray(coder, bytes); + + Assert.assertEquals(before, after); + } + + @Test + public void testGenericRecordEncoding() throws Exception { + String schemaString = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"User\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},\n" + + " {\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + Schema schema = (new Schema.Parser()).parse(schemaString); + + GenericRecord before = new GenericData.Record(schema); + before.put("name", "Bob"); + before.put("favorite_number", 256); + // Leave favorite_color null + + AvroCoder coder = AvroCoder.of(GenericRecord.class, schema); + byte[] bytes = CoderUtils.encodeToByteArray(coder, before); + GenericRecord after = CoderUtils.decodeFromByteArray(coder, bytes); + + Assert.assertEquals(before, after); + + Assert.assertEquals(schema, coder.getSchema()); + } + + @Test + public void testEncodingNotBuffered() throws Exception { + // This test ensures that the coder doesn't read ahead and buffer data. + // Reading ahead causes a problem if the stream consists of records of different + // types. + Pojo before = new Pojo("Hello", 42); + + AvroCoder coder = AvroCoder.of(Pojo.class); + SerializableCoder intCoder = SerializableCoder.of(Integer.class); + + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + + Context context = Context.NESTED; + coder.encode(before, outStream, context); + intCoder.encode(10, outStream, context); + + ByteArrayInputStream inStream = new ByteArrayInputStream(outStream.toByteArray()); + + Pojo after = coder.decode(inStream, context); + Assert.assertEquals(before, after); + + Integer intAfter = intCoder.decode(inStream, context); + Assert.assertEquals(new Integer(10), intAfter); + } + + @Test + public void testDefaultCoder() throws Exception { + Pipeline p = TestPipeline.create(); + + // Use MyRecord as input and output types without explicitly specifying + // a coder (this uses the default coders, which may not be AvroCoder). + PCollection output = + p.apply(Create.of(new Pojo("hello", 1), new Pojo("world", 2))) + .apply(ParDo.of(new GetTextFn())); + + DataflowAssert.that(output) + .containsInAnyOrder("hello", "world"); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoderTest.java new file mode 100644 index 000000000000..b6d2b3c657d0 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoderTest.java @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.util.common.CounterTestUtils; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** Unit tests for {@link ByteArrayCoder}. */ +@RunWith(JUnit4.class) +public class ByteArrayCoderTest { + @Test public void testOuterContext() throws CoderException, IOException { + byte[] buffer = {0xa, 0xb, 0xc}; + + ByteArrayOutputStream os = new ByteArrayOutputStream(); + ByteArrayCoder.of().encode(buffer, os, Coder.Context.OUTER); + byte[] encoded = os.toByteArray(); + + ByteArrayInputStream is = new ByteArrayInputStream(encoded); + byte[] decoded = ByteArrayCoder.of().decode(is, Coder.Context.OUTER); + assertThat(decoded, equalTo(buffer)); + } + + @Test public void testNestedContext() throws CoderException, IOException { + byte[][] buffers = {{0xa, 0xb, 0xc}, {}, {}, {0xd, 0xe}, {}}; + + ByteArrayOutputStream os = new ByteArrayOutputStream(); + for (byte[] buffer : buffers) { + ByteArrayCoder.of().encode(buffer, os, Coder.Context.NESTED); + } + byte[] encoded = os.toByteArray(); + + ByteArrayInputStream is = new ByteArrayInputStream(encoded); + for (byte[] buffer : buffers) { + byte[] decoded = ByteArrayCoder.of().decode(is, Coder.Context.NESTED); + assertThat(decoded, equalTo(buffer)); + } + } + + @Test public void testRegisterByteSizeObserver() throws Exception { + CounterTestUtils.testByteCount(ByteArrayCoder.of(), Coder.Context.OUTER, + new byte[][]{{ 0xa, 0xb, 0xc }}); + + CounterTestUtils.testByteCount(ByteArrayCoder.of(), Coder.Context.NESTED, + new byte[][]{{ 0xa, 0xb, 0xc }, {}, {}, { 0xd, 0xe }, {}}); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderProperties.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderProperties.java new file mode 100644 index 000000000000..ef096eb01c99 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderProperties.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeThat; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** + * Properties for use in {@link Coder} tests. These are implemented with junit assertions + * rather than as predicates for the sake of error messages. + */ +public class CoderProperties { + + /** + * Verifies that for the given {@link Coder}, {@link Coder.Context}, and values of + * type {@code T}, if the values are equal then the encoded bytes are equal. + */ + public static void coderDeterministic( + Coder coder, Coder.Context context, T value1, T value2) + throws Exception { + assumeThat(value1, equalTo(value2)); + assertArrayEquals(encode(coder, context, value1), encode(coder, context, value2)); + } + + /** + * Verifies that for the given {@link Coder}, {@link Coder.Context}, + * and value of type {@code T}, encoding followed by decoding yields an + * equal of type {@code T}. + */ + public static void coderDecodeEncodeEqual( + Coder coder, Coder.Context context, T value) + throws Exception { + assertEquals( + decode(coder, context, encode(coder, context, value)), + value); + } + + ////////////////////////////////////////////////////////////////////////// + + private static byte[] encode( + Coder coder, Coder.Context context, T value) throws CoderException, IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + coder.encode(value, os, context); + return os.toByteArray(); + } + + private static T decode( + Coder coder, Coder.Context context, byte[] bytes) throws CoderException, IOException { + ByteArrayInputStream is = new ByteArrayInputStream(bytes); + return coder.decode(is, context); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderRegistryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderRegistryTest.java new file mode 100644 index 000000000000..ace309482733 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderRegistryTest.java @@ -0,0 +1,230 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.reflect.TypeToken; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Tests for CoderRegistry. + */ +@RunWith(JUnit4.class) +public class CoderRegistryTest { + + public static CoderRegistry getStandardRegistry() { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + return registry; + } + + @Test + public void testRegisterInstantiatedGenericCoder() { + class MyValueList extends ArrayList { } + + CoderRegistry registry = new CoderRegistry(); + registry.registerCoder(MyValueList.class, ListCoder.of(MyValueCoder.of())); + assertEquals(registry.getDefaultCoder(MyValueList.class), ListCoder.of(MyValueCoder.of())); + } + + @Test + public void testSimpleDefaultCoder() { + CoderRegistry registry = getStandardRegistry(); + assertEquals(StringUtf8Coder.of(), registry.getDefaultCoder(String.class)); + assertEquals(null, registry.getDefaultCoder(UnknownType.class)); + } + + @Test + public void testTemplateDefaultCoder() { + CoderRegistry registry = getStandardRegistry(); + TypeToken> listToken = new TypeToken>() {}; + assertEquals(ListCoder.of(VarIntCoder.of()), + registry.getDefaultCoder(listToken)); + + registry.registerCoder(MyValue.class, MyValueCoder.class); + TypeToken>> kvToken = + new TypeToken>>() {}; + assertEquals(KvCoder.of(StringUtf8Coder.of(), + ListCoder.of(MyValueCoder.of())), + registry.getDefaultCoder(kvToken)); + + TypeToken> listUnknownToken = + new TypeToken>() {}; + assertEquals(null, registry.getDefaultCoder(listUnknownToken)); + } + + @Test + public void testTemplateInference() { + CoderRegistry registry = getStandardRegistry(); + MyTemplateClass> instance = + new MyTemplateClass>() {}; + Coder> expected = ListCoder.of(MyValueCoder.of()); + + // The map method operates on parameter names. + Map> coderMap = registry.getDefaultCoders( + instance.getClass(), + MyTemplateClass.class, + Collections.singletonMap("A", MyValueCoder.of())); + assertEquals(expected, coderMap.get("B")); + + // The array interface operates on position. + Coder[] coders = registry.getDefaultCoders( + instance.getClass(), + MyTemplateClass.class, + new Coder[] { MyValueCoder.of(), null }); + assertEquals(expected, coders[1]); + + // The "last argument" coder handles a common case. + Coder> actual = registry.getDefaultCoder( + instance.getClass(), + MyTemplateClass.class, + MyValueCoder.of()); + assertEquals(expected, actual); + + try { + registry.getDefaultCoder( + instance.getClass(), + MyTemplateClass.class, + BigEndianIntegerCoder.of()); + fail("should have failed"); + } catch (IllegalArgumentException exn) { + assertEquals("Cannot encode elements of type class " + + "com.google.cloud.dataflow.sdk.coders.CoderRegistryTest$MyValue " + + "with BigEndianIntegerCoder", exn.getMessage()); + } + } + + @Test + public void testGetDefaultCoderFromIntegerValue() { + CoderRegistry registry = getStandardRegistry(); + Integer i = 13; + Coder coder = registry.getDefaultCoder(i); + assertEquals(VarIntCoder.of(), coder); + } + + @Test + public void testGetDefaultCoderFromKvValue() { + CoderRegistry registry = getStandardRegistry(); + KV kv = KV.of(13, "hello"); + Coder> coder = registry.getDefaultCoder(kv); + assertEquals(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()), + coder); + } + + @Test + public void testGetDefaultCoderFromNestedKvValue() { + CoderRegistry registry = getStandardRegistry(); + KV>> kv = KV.of(13, KV.of(17L, KV.of("hello", "goodbye"))); + Coder>>> coder = registry.getDefaultCoder(kv); + assertEquals( + KvCoder.of(VarIntCoder.of(), + KvCoder.of(VarLongCoder.of(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))), + coder); + } + + @Test + public void testTypeCompatibility() { + assertTrue(CoderRegistry.isCompatible( + BigEndianIntegerCoder.of(), Integer.class)); + assertFalse(CoderRegistry.isCompatible( + BigEndianIntegerCoder.of(), String.class)); + + assertFalse(CoderRegistry.isCompatible( + ListCoder.of(BigEndianIntegerCoder.of()), Integer.class)); + assertTrue(CoderRegistry.isCompatible( + ListCoder.of(BigEndianIntegerCoder.of()), + new TypeToken>() {}.getType())); + assertFalse(CoderRegistry.isCompatible( + ListCoder.of(BigEndianIntegerCoder.of()), + new TypeToken>() {}.getType())); + } + + static class MyTemplateClass { } + + static class MyValue { } + + static class MyValueCoder implements Coder { + + private static final MyValueCoder INSTANCE = new MyValueCoder(); + + public static MyValueCoder of() { + return INSTANCE; + } + + public static List getInstanceComponents(MyValue exampleValue) { + return Arrays.asList(); + } + + @Override + public void encode(MyValue value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public MyValue decode(InputStream inStream, Context context) + throws CoderException, IOException { + return new MyValue(); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public CloudObject asCloudObject() { + return null; + } + + @Override + public boolean isDeterministic() { return true; } + + @Override + public boolean isRegisterByteSizeObserverCheap(MyValue value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + MyValue value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(0L); + } + } + + static class UnknownType { } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CustomCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CustomCoderTest.java new file mode 100644 index 000000000000..e532d44dc66b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CustomCoderTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** Unit tests for {@link CustomCoder}. */ +@RunWith(JUnit4.class) +public class CustomCoderTest { + + private static class MyCustomCoder extends CustomCoder> { + private final String key; + + public MyCustomCoder(String key) { + this.key = key; + } + + @Override + public void encode(KV kv, OutputStream out, Context context) + throws IOException { + new DataOutputStream(out).writeLong(kv.getValue()); + } + + @Override + public KV decode(InputStream inStream, Context context) + throws IOException { + return KV.of(key, new DataInputStream(inStream).readLong()); + } + + @Override + public boolean equals(Object other) { + return other instanceof MyCustomCoder + && key.equals(((MyCustomCoder) other).key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + } + + @Test public void testEncodeDecode() throws Exception { + MyCustomCoder coder = new MyCustomCoder("key"); + byte[] encoded = CoderUtils.encodeToByteArray(coder, KV.of("key", 3L)); + Assert.assertEquals( + KV.of("key", 3L), CoderUtils.decodeFromByteArray(coder, encoded)); + + byte[] encoded2 = CoderUtils.encodeToByteArray(coder, KV.of("ignored", 3L)); + Assert.assertEquals( + KV.of("key", 3L), CoderUtils.decodeFromByteArray(coder, encoded2)); + } + + @Test public void testEncodable() throws Exception { + SerializableUtils.ensureSerializable(new MyCustomCoder("key")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DefaultCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DefaultCoderTest.java new file mode 100644 index 000000000000..769d1e6fb144 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DefaultCoderTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.client.util.Preconditions; +import com.google.common.reflect.TypeToken; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests of Coder defaults. + */ +@RunWith(JUnit4.class) +public class DefaultCoderTest { + + @DefaultCoder(AvroCoder.class) + private static class AvroRecord { + } + + private static class SerializableBase implements Serializable { + } + + @DefaultCoder(SerializableCoder.class) + private static class SerializableRecord extends SerializableBase { + } + + @DefaultCoder(CustomCoder.class) + private static class CustomRecord extends SerializableBase { + } + + private static class Unknown { + } + + private static class CustomCoder extends SerializableCoder { + // Extending SerializableCoder isn't trivial, but it can be done. + @SuppressWarnings("unchecked") + public static SerializableCoder of(Class recordType) { + Preconditions.checkArgument( + CustomRecord.class.isAssignableFrom(recordType)); + return (SerializableCoder) new CustomCoder(); + } + + protected CustomCoder() { + super(CustomRecord.class); + } + } + + @Test + public void testDefaultCoders() throws Exception { + checkDefault(AvroRecord.class, AvroCoder.class); + checkDefault(SerializableBase.class, SerializableCoder.class); + checkDefault(SerializableRecord.class, SerializableCoder.class); + checkDefault(CustomRecord.class, CustomCoder.class); + } + + @Test + public void testUnknown() throws Exception { + CoderRegistry registery = new CoderRegistry(); + Coder coderType = registery.getDefaultCoder(Unknown.class); + Assert.assertNull(coderType); + } + + /** + * Checks that the default Coder for {@code valueType} is an instance of + * {@code expectedCoder}. + */ + private void checkDefault(Class valueType, + Class expectedCoder) { + CoderRegistry registry = new CoderRegistry(); + Coder coder = registry.getDefaultCoder(TypeToken.of(valueType)); + Assert.assertThat(coder, Matchers.instanceOf(expectedCoder)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/InstantCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/InstantCoderTest.java new file mode 100644 index 000000000000..dd719004eab1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/InstantCoderTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.primitives.UnsignedBytes; + +import org.joda.time.Instant; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** Unit tests for {@link InstantCoder}. */ +@RunWith(JUnit4.class) +public class InstantCoderTest { + private final InstantCoder coder = InstantCoder.of(); + private final List timestamps = + Arrays.asList(0L, 1L, -1L, -255L, 256L, Long.MIN_VALUE, Long.MAX_VALUE); + + @Test + public void testBasicEncoding() throws Exception { + for (long timestamp : timestamps) { + Assert.assertEquals(new Instant(timestamp), + CoderUtils.decodeFromByteArray(coder, + CoderUtils.encodeToByteArray(coder, new Instant(timestamp)))); + } + } + + @Test + public void testOrderedEncoding() throws Exception { + List sortedTimestamps = new ArrayList<>(timestamps); + Collections.sort(sortedTimestamps); + + List encodings = new ArrayList<>(sortedTimestamps.size()); + for (long timestamp : sortedTimestamps) { + encodings.add(CoderUtils.encodeToByteArray(coder, new Instant(timestamp))); + } + + // Verify that the encodings were already sorted, since they were generated + // in the correct order. + List sortedEncodings = new ArrayList<>(encodings); + Collections.sort(sortedEncodings, UnsignedBytes.lexicographicalComparator()); + + Assert.assertEquals(encodings, sortedEncodings); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/IterableCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/IterableCoderTest.java new file mode 100644 index 000000000000..993c5d0a5e91 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/IterableCoderTest.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link IterableCoder}. */ +@RunWith(JUnit4.class) +public class IterableCoderTest { + @Test + public void testGetInstanceComponentsNonempty() { + Iterable iterable = Arrays.asList(2, 58, 99, 5); + List components = IterableCoder.getInstanceComponents(iterable); + assertEquals(1, components.size()); + assertEquals(2, components.get(0)); + } + + @Test + public void testGetInstanceComponentsEmpty() { + Iterable iterable = Arrays.asList(); + List components = IterableCoder.getInstanceComponents(iterable); + assertNull(components); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ListCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ListCoderTest.java new file mode 100644 index 000000000000..c04d3e16745b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ListCoderTest.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link ListCoder}. */ +@RunWith(JUnit4.class) +public class ListCoderTest { + @Test + public void testGetInstanceComponentsNonempty() { + List list = Arrays.asList(21, 5, 3, 5); + List components = ListCoder.getInstanceComponents(list); + assertEquals(1, components.size()); + assertEquals(21, components.get(0)); + } + + @Test + public void testGetInstanceComponentsEmpty() { + List list = Arrays.asList(); + List components = ListCoder.getInstanceComponents(list); + assertNull(components); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/MapCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/MapCoderTest.java new file mode 100644 index 000000000000..30cd0d8e8100 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/MapCoderTest.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Unit tests for {@link MapCoder}. */ +@RunWith(JUnit4.class) +public class MapCoderTest { + @Test + public void testGetInstanceComponentsNonempty() { + Map map = new HashMap<>(); + map.put(17, "foozle"); + List components = MapCoder.getInstanceComponents(map); + assertEquals(2, components.size()); + assertEquals(17, components.get(0)); + assertEquals("foozle", components.get(1)); + } + + @Test + public void testGetInstanceComponentsEmpty() { + Map map = new HashMap<>(); + List components = MapCoder.getInstanceComponents(map); + assertNull(components); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SerializableCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SerializableCoderTest.java new file mode 100644 index 000000000000..3e56832a3faa --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SerializableCoderTest.java @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests SerializableCoder. + */ +@RunWith(JUnit4.class) +public class SerializableCoderTest implements Serializable { + + @DefaultCoder(SerializableCoder.class) + static class MyRecord implements Serializable { + public String value; + + public MyRecord(String value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + MyRecord myRecord = (MyRecord) o; + return value.equals(myRecord.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + } + + static class StringToRecord extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(new MyRecord(c.element())); + } + } + + static class RecordToString extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().value); + } + } + + static final List LINES = Arrays.asList( + "To be,", + "or not to be"); + + @Test + public void testSerializableCoder() throws Exception { + IterableCoder coder = IterableCoder + .of(SerializableCoder.of(MyRecord.class)); + + List records = new LinkedList<>(); + for (String l : LINES) { + records.add(new MyRecord(l)); + } + + byte[] encoded = CoderUtils.encodeToByteArray(coder, records); + Iterable decoded = CoderUtils.decodeFromByteArray(coder, encoded); + + assertEquals(records, decoded); + } + + @Test + public void testSerializableCoderConstruction() throws Exception { + SerializableCoder coder = SerializableCoder.of(MyRecord.class); + assertEquals(coder.getRecordType(), MyRecord.class); + + CloudObject encoding = coder.asCloudObject(); + Assert.assertThat(encoding.getClassName(), + Matchers.containsString(SerializableCoder.class.getSimpleName())); + + Coder decoded = Serializer.deserialize(encoding, Coder.class); + Assert.assertThat(decoded, Matchers.instanceOf(SerializableCoder.class)); + } + + @Test + public void testDefaultCoder() throws Exception { + Pipeline p = TestPipeline.create(); + + // Use MyRecord as input and output types without explicitly specifying + // a coder (this uses the default coders, which may not be + // SerializableCoder). + PCollection output = + p.apply(Create.of("Hello", "World")) + .apply(ParDo.of(new StringToRecord())) + .apply(ParDo.of(new RecordToString())); + + DataflowAssert.that(output) + .containsInAnyOrder("Hello", "World"); + } + + @Test + public void testLongStringEncoding() throws Exception { + StringUtf8Coder coder = StringUtf8Coder.of(); + + // Java's DataOutputStream.writeUTF fails at 64k, so test well beyond that. + char[] chars = new char[100 * 1024]; + Arrays.fill(chars, 'o'); + String source = new String(chars); + + // Verify OUTER encoding. + assertEquals(source, CoderUtils.decodeFromByteArray(coder, + CoderUtils.encodeToByteArray(coder, source))); + + // Second string uses a UTF8 character. Each codepoint is translated into + // 4 characters in UTF8. + int[] codePoints = new int[20 * 1024]; + Arrays.fill(codePoints, 0x1D50A); // "MATHEMATICAL_FRAKTUR_CAPITAL_G" + String source2 = new String(codePoints, 0, codePoints.length); + + // Verify OUTER encoding. + assertEquals(source2, CoderUtils.decodeFromByteArray(coder, + CoderUtils.encodeToByteArray(coder, source2))); + + + // Encode both strings into NESTED form. + byte[] nestedEncoding; + try (ByteArrayOutputStream os = new ByteArrayOutputStream()) { + coder.encode(source, os, Coder.Context.NESTED); + coder.encode(source2, os, Coder.Context.NESTED); + nestedEncoding = os.toByteArray(); + } + + // Decode from NESTED form. + try (ByteArrayInputStream is = new ByteArrayInputStream(nestedEncoding)) { + String result = coder.decode(is, Coder.Context.NESTED); + String result2 = coder.decode(is, Coder.Context.NESTED); + assertEquals(0, is.available()); + assertEquals(source, result); + assertEquals(source2, result2); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/URICoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/URICoderTest.java new file mode 100644 index 000000000000..f464e813bf5d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/URICoderTest.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link URICoder}. */ +@RunWith(JUnit4.class) +public class URICoderTest { + + // Test data + + private static final List TEST_URI_STRINGS = Arrays.asList( + "http://www.example.com", + "gs://myproject/mybucket/a/gcs/path", + "/just/a/path", + "file:/path/with/no/authority", + "file:///path/with/empty/authority"); + + private static final List TEST_CONTEXTS = Arrays.asList( + Coder.Context.OUTER, + Coder.Context.NESTED); + + // Tests + + @Test + public void testDeterministic() throws Exception { + Coder coder = URICoder.of(); + + for (String uriString : TEST_URI_STRINGS) { + for (Coder.Context context : TEST_CONTEXTS) { + // Obviously equal, but distinct as objects + CoderProperties.coderDeterministic(coder, context, new URI(uriString), new URI(uriString)); + } + } + } + + @Test + public void testDecodeEncodeEqual() throws Exception { + Coder coder = URICoder.of(); + + for (String uriString : TEST_URI_STRINGS) { + for (Coder.Context context : TEST_CONTEXTS) { + CoderProperties.coderDecodeEncodeEqual(coder, context, new URI(uriString)); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java new file mode 100644 index 000000000000..ad6f16567e92 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java @@ -0,0 +1,365 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.EvaluationResults; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.specific.SpecificDatumReader; +import org.apache.avro.specific.SpecificDatumWriter; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for AvroIO Read and Write transforms. + */ +@RunWith(JUnit4.class) +public class AvroIOTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + private File avroFile; + + @Before + public void prepareAvroFileBeforeAnyTest() throws IOException { + avroFile = tmpFolder.newFile("file.avro"); + } + + private final String schemaString = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"User\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},\n" + + " {\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + private final Schema.Parser parser = new Schema.Parser(); + private final Schema schema = parser.parse(schemaString); + + private User[] generateAvroObjects() { + User user1 = new User(); + user1.setName("Bob"); + user1.setFavoriteNumber(256); + + User user2 = new User(); + user2.setName("Alice"); + user2.setFavoriteNumber(128); + + User user3 = new User(); + user3.setName("Ted"); + user3.setFavoriteColor("white"); + + return new User[] { user1, user2, user3 }; + } + + private GenericRecord[] generateAvroGenericRecords() { + GenericRecord user1 = new GenericData.Record(schema); + user1.put("name", "Bob"); + user1.put("favorite_number", 256); + + GenericRecord user2 = new GenericData.Record(schema); + user2.put("name", "Alice"); + user2.put("favorite_number", 128); + + GenericRecord user3 = new GenericData.Record(schema); + user3.put("name", "Ted"); + user3.put("favorite_color", "white"); + + return new GenericRecord[] { user1, user2, user3 }; + } + + private void generateAvroFile(User[] elements) throws IOException { + DatumWriter userDatumWriter = new SpecificDatumWriter<>(User.class); + DataFileWriter dataFileWriter = new DataFileWriter<>(userDatumWriter); + dataFileWriter.create(elements[0].getSchema(), avroFile); + for (User user : elements) { + dataFileWriter.append(user); + } + dataFileWriter.close(); + } + + private List readAvroFile() throws IOException { + DatumReader userDatumReader = new SpecificDatumReader<>(User.class); + DataFileReader dataFileReader = new DataFileReader<>(avroFile, userDatumReader); + List users = new ArrayList<>(); + while (dataFileReader.hasNext()) { + users.add(dataFileReader.next()); + } + return users; + } + + void runTestRead(AvroIO.Read.Bound read, String expectedName, T[] expectedOutput) + throws Exception { + generateAvroFile(generateAvroObjects()); + + DirectPipeline p = DirectPipeline.createForTest(); + PCollection output = p.apply(read); + EvaluationResults results = p.run(); + assertEquals(expectedName, output.getName()); + assertThat(results.getPCollection(output), + containsInAnyOrder(expectedOutput)); + } + + @Test + public void testReadFromGeneratedClass() throws Exception { + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .withSchema(User.class), + "AvroIO.Read.out", generateAvroObjects()); + runTestRead(AvroIO.Read.withSchema(User.class) + .from(avroFile.getPath()), + "AvroIO.Read.out", generateAvroObjects()); + runTestRead(AvroIO.Read.named("MyRead") + .from(avroFile.getPath()) + .withSchema(User.class), + "MyRead.out", generateAvroObjects()); + runTestRead(AvroIO.Read.named("MyRead") + .withSchema(User.class) + .from(avroFile.getPath()), + "MyRead.out", generateAvroObjects()); + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .withSchema(User.class) + .named("HerRead"), + "HerRead.out", generateAvroObjects()); + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .named("HerRead") + .withSchema(User.class), + "HerRead.out", generateAvroObjects()); + runTestRead(AvroIO.Read.withSchema(User.class) + .named("HerRead") + .from(avroFile.getPath()), + "HerRead.out", generateAvroObjects()); + runTestRead(AvroIO.Read.withSchema(User.class) + .from(avroFile.getPath()) + .named("HerRead"), + "HerRead.out", generateAvroObjects()); + } + + @Test + public void testReadFromSchema() throws Exception { + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .withSchema(schema), + "AvroIO.Read.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.withSchema(schema) + .from(avroFile.getPath()), + "AvroIO.Read.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.named("MyRead") + .from(avroFile.getPath()) + .withSchema(schema), + "MyRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.named("MyRead") + .withSchema(schema) + .from(avroFile.getPath()), + "MyRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .withSchema(schema) + .named("HerRead"), + "HerRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .named("HerRead") + .withSchema(schema), + "HerRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.withSchema(schema) + .named("HerRead") + .from(avroFile.getPath()), + "HerRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.withSchema(schema) + .from(avroFile.getPath()) + .named("HerRead"), + "HerRead.out", generateAvroGenericRecords()); + } + + @Test + public void testReadFromSchemaString() throws Exception { + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .withSchema(schemaString), + "AvroIO.Read.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.withSchema(schemaString) + .from(avroFile.getPath()), + "AvroIO.Read.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.named("MyRead") + .from(avroFile.getPath()) + .withSchema(schemaString), + "MyRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.named("MyRead") + .withSchema(schemaString) + .from(avroFile.getPath()), + "MyRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .withSchema(schemaString) + .named("HerRead"), + "HerRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.from(avroFile.getPath()) + .named("HerRead") + .withSchema(schemaString), + "HerRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.withSchema(schemaString) + .named("HerRead") + .from(avroFile.getPath()), + "HerRead.out", generateAvroGenericRecords()); + runTestRead(AvroIO.Read.withSchema(schemaString) + .from(avroFile.getPath()) + .named("HerRead"), + "HerRead.out", generateAvroGenericRecords()); + } + + void runTestWrite(AvroIO.Write.Bound write, String expectedName) + throws Exception { + User[] users = generateAvroObjects(); + + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = p.apply(Create.of(Arrays.asList((T[]) users))) + .setCoder((Coder) AvroCoder.of(User.class)); + PDone output = input.apply(write.withoutSharding()); + EvaluationResults results = p.run(); + assertEquals(expectedName, write.getName()); + + assertThat(readAvroFile(), containsInAnyOrder(users)); + } + + @Test + public void testWriteFromGeneratedClass() throws Exception { + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(User.class), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.withSchema(User.class) + .to(avroFile.getPath()), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.named("MyWrite") + .to(avroFile.getPath()) + .withSchema(User.class), + "MyWrite"); + runTestWrite(AvroIO.Write.named("MyWrite") + .withSchema(User.class) + .to(avroFile.getPath()), + "MyWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(User.class) + .named("HerWrite"), + "HerWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .named("HerWrite") + .withSchema(User.class), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(User.class) + .named("HerWrite") + .to(avroFile.getPath()), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(User.class) + .to(avroFile.getPath()) + .named("HerWrite"), + "HerWrite"); + } + + @Test + public void testWriteFromSchema() throws Exception { + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schema), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.withSchema(schema) + .to(avroFile.getPath()), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.named("MyWrite") + .to(avroFile.getPath()) + .withSchema(schema), + "MyWrite"); + runTestWrite(AvroIO.Write.named("MyWrite") + .withSchema(schema) + .to(avroFile.getPath()), + "MyWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schema) + .named("HerWrite"), + "HerWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .named("HerWrite") + .withSchema(schema), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schema) + .named("HerWrite") + .to(avroFile.getPath()), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schema) + .to(avroFile.getPath()) + .named("HerWrite"), + "HerWrite"); + } + + @Test + public void testWriteFromSchemaString() throws Exception { + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schemaString), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.withSchema(schemaString) + .to(avroFile.getPath()), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.named("MyWrite") + .to(avroFile.getPath()) + .withSchema(schemaString), + "MyWrite"); + runTestWrite(AvroIO.Write.named("MyWrite") + .withSchema(schemaString) + .to(avroFile.getPath()), + "MyWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schemaString) + .named("HerWrite"), + "HerWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .named("HerWrite") + .withSchema(schemaString), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schemaString) + .named("HerWrite") + .to(avroFile.getPath()), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schemaString) + .to(avroFile.getPath()) + .named("HerWrite"), + "HerWrite"); + } + + // TODO: for Write only, test withSuffix, withNumShards, + // withShardNameTemplate and withoutSharding. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BigQueryIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BigQueryIOTest.java new file mode 100644 index 000000000000..863e260282a3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BigQueryIOTest.java @@ -0,0 +1,307 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.junit.Assert.assertEquals; + +import com.google.api.client.util.Data; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +/** + * Tests for BigQueryIO. + */ +@RunWith(JUnit4.class) +public class BigQueryIOTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private void checkReadObject( + BigQueryIO.Read.Bound bound, String project, String dataset, String table) { + checkReadObjectWithValidate(bound, project, dataset, table, true); + } + + private void checkReadObjectWithValidate( + BigQueryIO.Read.Bound bound, String project, String dataset, String table, boolean validate) { + assertEquals(project, bound.table.getProjectId()); + assertEquals(dataset, bound.table.getDatasetId()); + assertEquals(table, bound.table.getTableId()); + assertEquals(validate, bound.validate); + } + + private void checkWriteObject( + BigQueryIO.Write.Bound bound, String project, String dataset, String table, + TableSchema schema, CreateDisposition createDisposition, + WriteDisposition writeDisposition) { + checkWriteObjectWithValidate( + bound, project, dataset, table, schema, createDisposition, writeDisposition, true); + } + + private void checkWriteObjectWithValidate( + BigQueryIO.Write.Bound bound, String project, String dataset, String table, + TableSchema schema, CreateDisposition createDisposition, + WriteDisposition writeDisposition, boolean validate) { + assertEquals(project, bound.table.getProjectId()); + assertEquals(dataset, bound.table.getDatasetId()); + assertEquals(table, bound.table.getTableId()); + assertEquals(schema, bound.schema); + assertEquals(createDisposition, bound.createDisposition); + assertEquals(writeDisposition, bound.writeDisposition); + assertEquals(validate, bound.validate); + } + + @Before + public void setUp() { + BigQueryOptions options = PipelineOptionsFactory.as(BigQueryOptions.class); + options.setProject("defaultProject"); + } + + @Test + public void testBuildSource() throws IOException { + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from("foo.com:project:somedataset.sometable"); + checkReadObject(bound, "foo.com:project", "somedataset", "sometable"); + } + + @Test + public void testBuildSourcewithoutValidation() throws IOException { + // This test just checks that using withoutValidation will not trigger object + // construction errors. + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from("foo.com:project:somedataset.sometable").withoutValidation(); + checkReadObjectWithValidate(bound, "foo.com:project", "somedataset", "sometable", false); + } + + @Test + public void testBuildSourceWithDefaultProject() throws IOException { + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from("somedataset.sometable"); + checkReadObject(bound, null, "somedataset", "sometable"); + } + + @Test + public void testBuildSourceWithTableReference() throws IOException { + TableReference table = new TableReference() + .setProjectId("foo.com:project") + .setDatasetId("somedataset") + .setTableId("sometable"); + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from(table); + checkReadObject(bound, "foo.com:project", "somedataset", "sometable"); + } + + @Test(expected = IllegalStateException.class) + public void testBuildSourceWithoutTable() throws IOException { + Pipeline p = TestPipeline.create(); + p.apply(BigQueryIO.Read.named("ReadMyTable")); + } + + @Test + public void testBuildSink() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable"); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkwithoutValidation() throws IOException { + // This test just checks that using withoutValidation will not trigger object + // construction errors. + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable").withoutValidation(); + checkWriteObjectWithValidate( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY, false); + } + + @Test + public void testBuildSinkDefaultProject() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("somedataset.sometable"); + checkWriteObject( + bound, null, "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithTableReference() throws IOException { + TableReference table = new TableReference() + .setProjectId("foo.com:project") + .setDatasetId("somedataset") + .setTableId("sometable"); + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to(table); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test(expected = IllegalStateException.class) + public void testBuildSinkWithoutTable() throws IOException { + Pipeline p = TestPipeline.create(); + p.apply(Create.of()).setCoder(TableRowJsonCoder.of()) + .apply(BigQueryIO.Write.named("WriteMyTable")); + } + + @Test + public void testBuildSinkWithSchema() throws IOException { + TableSchema schema = new TableSchema(); + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable").withSchema(schema); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + schema, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithCreateDispositionNever() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withCreateDisposition(CreateDisposition.CREATE_NEVER); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_NEVER, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithCreateDispositionIfNeeded() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithWriteDispositionTruncate() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_TRUNCATE); + } + + @Test + public void testBuildSinkWithWriteDispositionAppend() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withWriteDisposition(WriteDisposition.WRITE_APPEND); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_APPEND); + } + + @Test + public void testBuildSinkWithWriteDispositionEmpty() throws IOException { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withWriteDisposition(WriteDisposition.WRITE_EMPTY); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testTableParsing() { + TableReference ref = BigQueryIO + .parseTableSpec("my-project:data_set.table_name"); + Assert.assertEquals("my-project", ref.getProjectId()); + Assert.assertEquals("data_set", ref.getDatasetId()); + Assert.assertEquals("table_name", ref.getTableId()); + } + + @Test + public void testTableParsing_validPatterns() { + BigQueryIO.parseTableSpec("a123-456:foo_bar.d"); + BigQueryIO.parseTableSpec("a12345:b.c"); + BigQueryIO.parseTableSpec("b12345.c"); + } + + @Test + public void testTableParsing_noProjectId() { + TableReference ref = BigQueryIO + .parseTableSpec("data_set.table_name"); + Assert.assertEquals(null, ref.getProjectId()); + Assert.assertEquals("data_set", ref.getDatasetId()); + Assert.assertEquals("table_name", ref.getTableId()); + } + + @Test + public void testTableParsingError() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec("0123456:foo.bar"); + } + + @Test + public void testTableParsingError_2() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec("myproject:.bar"); + } + + @Test + public void testTableParsingError_3() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec(":a.b"); + } + + @Test + public void testTableParsingError_slash() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec("a\\b12345:c.d"); + } + + // Test that BigQuery's special null placeholder objects can be encoded. + @Test + public void testCoder_nullCell() throws CoderException { + TableRow row = new TableRow(); + row.set("temperature", Data.nullOf(Object.class)); + row.set("max_temperature", Data.nullOf(Object.class)); + + byte[] bytes = CoderUtils.encodeToByteArray(TableRowJsonCoder.of(), row); + + TableRow newRow = CoderUtils.decodeFromByteArray(TableRowJsonCoder.of(), bytes); + byte[] newBytes = CoderUtils.encodeToByteArray(TableRowJsonCoder.of(), newRow); + + Assert.assertArrayEquals(bytes, newBytes); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/DatastoreIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/DatastoreIOTest.java new file mode 100644 index 000000000000..e026c58102da --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/DatastoreIOTest.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.junit.Assert.assertEquals; + +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.EntityCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; + +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for DatastoreIO Read and Write transforms. + */ +@RunWith(JUnit4.class) +public class DatastoreIOTest { + + private String host; + private String datasetId; + private Query query; + + /** + * Sets the default dataset ID as "shakespearedataset", + * which contains two kinds of records: "food" and "shakespeare". + * The "food" table contains 10 manually constructed entities, + * The "shakespeare" table contains 172948 entities, + * where each entity represents one line in one play in + * Shakespeare collections (e.g. there are 172948 lines in + * all Shakespeare files). + * + *

The function also sets up the datastore agent by creating + * a Datastore object to access the dataset shakespeareddataset. + * + *

Note that the local server must be started to let the agent + * be created normally. + */ + @Before + public void setUp() { + this.host = "http://localhost:1234"; + this.datasetId = "shakespearedataset"; + + Query.Builder q = Query.newBuilder(); + q.addKindBuilder().setName("shakespeare"); + this.query = q.build(); + } + + /** + * Test for reading one entity from kind "food" + */ + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testBuildRead() throws Exception { + DatastoreIO.Read.Bound readQuery = DatastoreIO.Read + .withHost(this.host) + .from(this.datasetId, this.query); + assertEquals(this.query, readQuery.query); + assertEquals(this.datasetId, readQuery.datasetId); + assertEquals(this.host, readQuery.host); + } + + @Test + public void testBuildReadAlt() throws Exception { + DatastoreIO.Read.Bound readQuery = DatastoreIO.Read + .from(this.datasetId, this.query) + .withHost(this.host); + assertEquals(this.query, readQuery.query); + assertEquals(this.datasetId, readQuery.datasetId); + assertEquals(this.host, readQuery.host); + } + + @Test(expected = IllegalStateException.class) + public void testBuildReadWithoutDatastoreSettingToCatchException() + throws Exception { + // create pipeline and run the pipeline to get result + Pipeline p = DirectPipeline.createForTest(); + p.apply(DatastoreIO.Read.named("ReadDatastore")); + } + + @Test + public void testBuildWrite() throws Exception { + DatastoreIO.Write.Bound write = DatastoreIO.Write + .to(this.datasetId) + .withHost(this.host); + assertEquals(this.host, write.host); + assertEquals(this.datasetId, write.datasetId); + } + + @Test + public void testBuildWriteAlt() throws Exception { + DatastoreIO.Write.Bound write = DatastoreIO.Write + .withHost(this.host) + .to(this.datasetId); + assertEquals(this.host, write.host); + assertEquals(this.datasetId, write.datasetId); + } + + @Test(expected = IllegalStateException.class) + public void testBuildWriteWithoutDatastoreToCatchException() throws Exception { + // create pipeline and run the pipeline to get result + Pipeline p = DirectPipeline.createForTest(); + p.apply(Create.of()).setCoder(EntityCoder.of()) + .apply(DatastoreIO.Write.named("WriteDatastore")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java new file mode 100644 index 000000000000..b6aaf59b51ad --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.TestUtils.INTS_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_INTS_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES_ARRAY; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TextualIntegerCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.EvaluationResults; +import com.google.cloud.dataflow.sdk.testing.TestDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileOutputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.ByteBuffer; +import java.nio.channels.SeekableByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for TextIO Read and Write transforms. + */ +@RunWith(JUnit4.class) +public class TextIOTest { + + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private static class EmptySeekableByteChannel implements SeekableByteChannel { + public long position() { + return 0L; + } + + public SeekableByteChannel position(long newPosition) { + return this; + } + + public long size() { + return 0L; + } + + public SeekableByteChannel truncate(long size) { + return this; + } + + public int write(ByteBuffer src) { + return 0; + } + + public int read(ByteBuffer dst) { + return 0; + } + + public boolean isOpen() { + return true; + } + + public void close() { } + } + + private GcsUtil buildMockGcsUtil() throws IOException { + GcsUtil mockGcsUtil = Mockito.mock(GcsUtil.class); + + // Any request to open gets a new bogus channel + Mockito + .when(mockGcsUtil.open(Mockito.any(GcsPath.class))) + .thenReturn(new EmptySeekableByteChannel()); + + // Any request for expansion gets a single bogus URL + // after we first run the expansion code (which will generally + // return no results, which causes a crash we aren't testing) + Mockito + .when(mockGcsUtil.expand(Mockito.any(GcsPath.class))) + .thenReturn(Arrays.asList(GcsPath.fromUri("gs://bucket/foo"))); + + return mockGcsUtil; + } + + private TestDataflowPipelineOptions buildTestPipelineOptions() { + TestDataflowPipelineOptions options = + PipelineOptionsFactory.as(TestDataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + return options; + } + + void runTestRead(T[] expected, Coder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + try (PrintStream writer = new PrintStream(new FileOutputStream(tmpFile))) { + for (T elem : expected) { + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, elem); + String line = new String(encodedElem); + writer.println(line); + } + } + + DirectPipeline p = DirectPipeline.createForTest(); + + TextIO.Read.Bound read; + if (coder.equals(StringUtf8Coder.of())) { + TextIO.Read.Bound readStrings = TextIO.Read.from(filename); + // T==String + read = (TextIO.Read.Bound) readStrings; + } else { + read = TextIO.Read.from(filename).withCoder(coder); + } + + PCollection output = p.apply(read); + + EvaluationResults results = p.run(); + + assertThat(results.getPCollection(output), + containsInAnyOrder(expected)); + } + + @Test + public void testReadStrings() throws Exception { + runTestRead(LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testReadEmptyStrings() throws Exception { + runTestRead(NO_LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testReadInts() throws Exception { + runTestRead(INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testReadEmptyInts() throws Exception { + runTestRead(NO_INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testReadNamed() { + Pipeline p = DirectPipeline.createForTest(); + + { + PCollection output1 = + p.apply(TextIO.Read.from("/tmp/file.txt")); + assertEquals("TextIO.Read.out", output1.getName()); + } + + { + PCollection output2 = + p.apply(TextIO.Read.named("MyRead").from("/tmp/file.txt")); + assertEquals("MyRead.out", output2.getName()); + } + + { + PCollection output3 = + p.apply(TextIO.Read.from("/tmp/file.txt").named("HerRead")); + assertEquals("HerRead.out", output3.getName()); + } + } + + void runTestWrite(T[] elems, Coder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = + p.apply(Create.of(Arrays.asList(elems))).setCoder(coder); + + TextIO.Write.Bound write; + if (coder.equals(StringUtf8Coder.of())) { + TextIO.Write.Bound writeStrings = + TextIO.Write.to(filename).withoutSharding(); + // T==String + write = (TextIO.Write.Bound) writeStrings; + } else { + write = TextIO.Write.to(filename).withCoder(coder).withoutSharding(); + } + + PDone output = input.apply(write); + + EvaluationResults results = p.run(); + + BufferedReader reader = new BufferedReader(new FileReader(tmpFile)); + List actual = new ArrayList<>(); + for (;;) { + String line = reader.readLine(); + if (line == null) { + break; + } + actual.add(line); + } + + String[] expected = new String[elems.length]; + for (int i = 0; i < elems.length; i++) { + T elem = elems[i]; + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, elem); + String line = new String(encodedElem); + expected[i] = line; + } + + assertThat(actual, + containsInAnyOrder(expected)); + } + + @Test + public void testWriteStrings() throws Exception { + runTestWrite(LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testWriteEmptyStrings() throws Exception { + runTestWrite(NO_LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testWriteInts() throws Exception { + runTestWrite(INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testWriteEmptyInts() throws Exception { + runTestWrite(NO_INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testWriteSharded() throws IOException { + File outFolder = tmpFolder.newFolder(); + String filename = outFolder.toPath().resolve("output").toString(); + + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = + p.apply(Create.of(Arrays.asList(LINES_ARRAY))) + .setCoder(StringUtf8Coder.of()); + + PDone done = input.apply( + TextIO.Write.to(filename).withNumShards(2).withSuffix(".txt")); + + EvaluationResults results = p.run(); + + String[] files = outFolder.list(); + + assertThat(Arrays.asList(files), + containsInAnyOrder("output-00000-of-00002.txt", + "output-00001-of-00002.txt")); + } + + @Test + public void testWriteNamed() { + Pipeline p = DirectPipeline.createForTest(); + + PCollection input = + p.apply(Create.of(LINES)).setCoder(StringUtf8Coder.of()); + + { + PTransform, PDone> transform1 = + TextIO.Write.to("/tmp/file.txt"); + assertEquals("TextIO.Write", transform1.getName()); + } + + { + PTransform, PDone> transform2 = + TextIO.Write.named("MyWrite").to("/tmp/file.txt"); + assertEquals("MyWrite", transform2.getName()); + } + + { + PTransform, PDone> transform3 = + TextIO.Write.to("/tmp/file.txt").named("HerWrite"); + assertEquals("HerWrite", transform3.getName()); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testUnsupportedFilePattern() throws IOException { + File outFolder = tmpFolder.newFolder(); + String filename = outFolder.toPath().resolve("output@*").toString(); + + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = + p.apply(Create.of(Arrays.asList(LINES_ARRAY))) + .setCoder(StringUtf8Coder.of()); + + PDone done = input.apply(TextIO.Write.to(filename)); + + EvaluationResults results = p.run(); + Assert.fail("Expected failure due to unsupported output pattern"); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests a few corner cases that should not crash. + */ + @Test + public void testGoodWildcards() throws Exception { + TestDataflowPipelineOptions options = buildTestPipelineOptions(); + options.setGcsUtil(buildMockGcsUtil()); + + Pipeline pipeline = Pipeline.create(options); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/*")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/?")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/[0-9]")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/*baz*")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/*baz?")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/[0-9]baz?")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/baz/*")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/baz/*wonka*")); + + // Check that running doesn't fail. + pipeline.run(); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests "*". + */ + @Test + public void testBadWildcardStar() throws Exception { + Pipeline pipeline = Pipeline.create(buildTestPipelineOptions()); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo*/baz")); + + // Check that running does fail. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("wildcard"); + pipeline.run(); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests "?". + */ + @Test + public void testBadWildcardOptional() throws Exception { + Pipeline pipeline = Pipeline.create(buildTestPipelineOptions()); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo?/baz")); + + // Check that running does fail. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("wildcard"); + pipeline.run(); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests "[]" based character classes. + */ + @Test + public void testBadWildcardBrackets() throws Exception { + Pipeline pipeline = Pipeline.create(buildTestPipelineOptions()); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo[0-9]/baz")); + + // Check that translation does fail. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("wildcard"); + pipeline.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/user.avsc b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/user.avsc new file mode 100644 index 000000000000..451a19fa12c3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/user.avsc @@ -0,0 +1,10 @@ +{ + "namespace": "com.google.cloud.dataflow.sdk.io", + "type": "record", + "name": "User", + "fields": [ + { "name": "name", "type": "string"}, + { "name": "favorite_number", "type": ["int", "null"]}, + { "name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptionsTest.java new file mode 100644 index 000000000000..f4d6f0499d44 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptionsTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.testing.ResetDateTimeProvider; +import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DataflowPipelineOptions}. */ +@RunWith(JUnit4.class) +public class DataflowPipelineOptionsTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + @Rule public ResetDateTimeProvider resetDateTimeProviderRule = new ResetDateTimeProvider(); + + @Test + public void testJobNameIsSet() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setJobName("TestJobName"); + assertEquals("TestJobName", options.getJobName()); + } + + @Test + public void testUserNameIsNotSet() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().remove("user.name"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("TestApplication"); + assertEquals("testapplication--1208190706", options.getJobName()); + assertTrue(options.getJobName().length() <= 40); + } + + @Test + public void testAppNameAndUserNameIsTooLong() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "abcdeabcdeabcdeabcdeabcdeabcde"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("1234567890123456789012345678901234567890"); + assertEquals("a234567890123456789-abcdeabcd-1208190706", options.getJobName()); + assertTrue(options.getJobName().length() <= 40); + } + + @Test + public void testAppNameIsTooLong() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "abcde"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("1234567890123456789012345678901234567890"); + assertEquals("a2345678901234567890123-abcde-1208190706", options.getJobName()); + assertTrue(options.getJobName().length() <= 40); + } + + @Test + public void testUserNameIsTooLong() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "abcdeabcdeabcdeabcdeabcdeabcde"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("1234567890"); + assertEquals("a234567890-abcdeabcdeabcdeabc-1208190706", options.getJobName()); + assertTrue(options.getJobName().length() <= 40); + } + + + @Test + public void testUtf8UserNameAndApplicationNameIsNormalized() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "ði ıntəˈnæʃənəl "); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("fəˈnɛtık əsoʊsiˈeıʃn"); + assertEquals("f00n0t0k00so0si0e00-0i00nt00n-1208190706", options.getJobName()); + assertTrue(options.getJobName().length() <= 40); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java new file mode 100644 index 000000000000..ca1e9502bf97 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java @@ -0,0 +1,502 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** Tests for {@link PipelineOptionsFactory}. */ +@RunWith(JUnit4.class) +public class PipelineOptionsFactoryTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + + @Test + public void testCreationFromSystemProperties() { + System.getProperties().putAll(ImmutableMap + .builder() + .put("root_url", "test_root_url") + .put("service_path", "test_service_path") + .put("temp_gcs_directory", + "gs://tap-testing-30lsaafg6g3zudmjbnsdz6wj/unittesting/staging") + .put("service_account_name", "test_service_account_name") + .put("service_account_keyfile", "test_service_account_keyfile") + .put("worker_id", "test_worker_id") + .put("project_id", "test_project_id") + .put("job_id", "test_job_id") + .build()); + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + assertEquals("test_root_url", options.getApiRootUrl()); + assertEquals("test_service_path", options.getDataflowEndpoint()); + assertEquals("gs://tap-testing-30lsaafg6g3zudmjbnsdz6wj/unittesting/staging", + options.getTempLocation()); + assertEquals("test_service_account_name", options.getServiceAccountName()); + assertEquals("test_service_account_keyfile", options.getServiceAccountKeyfile()); + assertEquals("test_worker_id", options.getWorkerId()); + assertEquals("test_project_id", options.getProject()); + assertEquals("test_job_id", options.getJobId()); + } + + @Test + public void testAppNameIsSet() { + ApplicationNameOptions options = PipelineOptionsFactory.as(ApplicationNameOptions.class); + assertEquals(PipelineOptionsFactoryTest.class.getSimpleName(), options.getAppName()); + } + + /** A simple test interface. */ + public static interface TestPipelineOptions extends PipelineOptions { + String getTestPipelineOption(); + void setTestPipelineOption(String value); + } + + @Test + public void testAppNameIsSetWhenUsingAs() { + TestPipelineOptions options = PipelineOptionsFactory.as(TestPipelineOptions.class); + assertEquals(PipelineOptionsFactoryTest.class.getSimpleName(), + options.as(ApplicationNameOptions.class).getAppName()); + } + + @Test + public void testManualRegistration() { + assertFalse(PipelineOptionsFactory.getRegisteredOptions().contains(TestPipelineOptions.class)); + PipelineOptionsFactory.register(TestPipelineOptions.class); + assertTrue(PipelineOptionsFactory.getRegisteredOptions().contains(TestPipelineOptions.class)); + } + + @Test + public void testDefaultRegistration() { + assertTrue(PipelineOptionsFactory.getRegisteredOptions().contains(PipelineOptions.class)); + } + + /** A test interface missing a getter. */ + public static interface MissingGetter extends PipelineOptions { + void setObject(Object value); + } + + @Test + public void testMissingGetterThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected getter for property [object] of type [java.lang.Object] on " + + "[com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingGetter]."); + + PipelineOptionsFactory.as(MissingGetter.class); + } + + /** A test interface missing a setter. */ + public static interface MissingSetter extends PipelineOptions { + Object getObject(); + } + + @Test + public void testMissingSetterThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected setter for property [object] of type [java.lang.Object] on " + + "[com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingSetter]."); + + PipelineOptionsFactory.as(MissingSetter.class); + } + + /** A test interface representing a composite interface. */ + public static interface CombinedObject extends MissingGetter, MissingSetter { + } + + @Test + public void testHavingSettersGettersFromSeparateInterfacesIsValid() { + PipelineOptionsFactory.as(CombinedObject.class); + } + + /** A test interface which contains a non-bean style method. */ + public static interface ExtraneousMethod extends PipelineOptions { + public String extraneousMethod(int value, String otherValue); + } + + @Test + public void testHavingExtraneousMethodThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Methods [java.lang.String extraneousMethod(int, java.lang.String)] on " + + "[com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$ExtraneousMethod] " + + "do not conform to being bean properties."); + + PipelineOptionsFactory.as(ExtraneousMethod.class); + } + + /** A test interface which has a conflicting return type with its parent. */ + public static interface ReturnTypeConflict extends CombinedObject { + @Override + String getObject(); + void setObject(String value); + } + + @Test + public void testReturnTypeConflictThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Method [getObject] has multiple definitions [public abstract java.lang.Object " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingSetter" + + ".getObject(), public abstract java.lang.String " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$ReturnTypeConflict" + + ".getObject()] with different return types for [" + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$ReturnTypeConflict]."); + PipelineOptionsFactory.as(ReturnTypeConflict.class); + } + + /** Test interface that has {@link JsonIgnore @JsonIgnore} on a setter for a property. */ + public static interface SetterWithJsonIgnore extends PipelineOptions { + String getValue(); + @JsonIgnore + void setValue(String value); + } + + @Test + public void testSetterAnnotatedWithJsonIgnore() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected setter for property [value] to not be marked with @JsonIgnore on [com." + + "google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$SetterWithJsonIgnore]"); + PipelineOptionsFactory.as(SetterWithJsonIgnore.class); + } + + /** + * This class is has a conflicting field with {@link CombinedObject} that doesn't have + * {@link JsonIgnore @JsonIgnore}. + */ + public static interface GetterWithJsonIgnore extends PipelineOptions { + @JsonIgnore + Object getObject(); + void setObject(Object value); + } + + @Test + public void testNotAllGettersAnnotatedWithJsonIgnore() throws Exception { + // Initial construction is valid. + GetterWithJsonIgnore options = PipelineOptionsFactory.as(GetterWithJsonIgnore.class); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected getter for property [object] to be marked with @JsonIgnore on all [com." + + "google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingSetter, " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$GetterWithJsonIgnore], " + + "found only on [com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$GetterWithJsonIgnore]"); + + // When we attempt to convert, we should error at this moment. + options.as(CombinedObject.class); + } + + @Test + public void testAppNameIsNotOverriddenWhenPassedInViaCommandLine() { + ApplicationNameOptions options = PipelineOptionsFactory + .fromArgs(new String[]{ "--appName=testAppName" }) + .as(ApplicationNameOptions.class); + assertEquals("testAppName", options.getAppName()); + } + + @Test + public void testPropertyIsSetOnRegisteredPipelineOptionNotPartOfOriginalInterface() { + PipelineOptions options = PipelineOptionsFactory + .fromArgs(new String[]{ "--project=testProject" }) + .create(); + assertEquals("testProject", options.as(GcpOptions.class).getProject()); + } + + /** A test interface containing all the primitives */ + public static interface Primitives extends PipelineOptions { + boolean getBoolean(); + void setBoolean(boolean value); + char getChar(); + void setChar(char value); + byte getByte(); + void setByte(byte value); + short getShort(); + void setShort(short value); + int getInt(); + void setInt(int value); + long getLong(); + void setLong(long value); + float getFloat(); + void setFloat(float value); + double getDouble(); + void setDouble(double value); + } + + @Test + public void testPrimitives() { + String[] args = new String[] { + "--boolean=true", + "--char=d", + "--byte=12", + "--short=300", + "--int=100000", + "--long=123890123890", + "--float=55.5", + "--double=12.3"}; + + Primitives options = PipelineOptionsFactory.fromArgs(args).as(Primitives.class); + assertTrue(options.getBoolean()); + assertEquals('d', options.getChar()); + assertEquals((byte) 12, options.getByte()); + assertEquals((short) 300, options.getShort()); + assertEquals(100000, options.getInt()); + assertEquals(123890123890L, options.getLong()); + assertEquals(55.5f, options.getFloat(), 0.0f); + assertEquals(12.3, options.getDouble(), 0.0); + } + + @Test + public void testBooleanShorthandArgument() { + String[] args = new String[] {"--boolean"}; + + Primitives options = PipelineOptionsFactory.fromArgs(args).as(Primitives.class); + assertTrue(options.getBoolean()); + } + + /** A test interface containing all supported objects */ + public static interface Objects extends PipelineOptions { + Boolean getBoolean(); + void setBoolean(Boolean value); + Character getChar(); + void setChar(Character value); + Byte getByte(); + void setByte(Byte value); + Short getShort(); + void setShort(Short value); + Integer getInt(); + void setInt(Integer value); + Long getLong(); + void setLong(Long value); + Float getFloat(); + void setFloat(Float value); + Double getDouble(); + void setDouble(Double value); + String getString(); + void setString(String value); + Class getClassValue(); + void setClassValue(Class value); + } + + @Test + public void testObjects() { + String[] args = new String[] { + "--boolean=true", + "--char=d", + "--byte=12", + "--short=300", + "--int=100000", + "--long=123890123890", + "--float=55.5", + "--double=12.3", + "--string=stringValue", + "--classValue=" + PipelineOptionsFactoryTest.class.getName()}; + + Objects options = PipelineOptionsFactory.fromArgs(args).as(Objects.class); + assertTrue(options.getBoolean()); + assertEquals(Character.valueOf('d'), options.getChar()); + assertEquals(Byte.valueOf((byte) 12), options.getByte()); + assertEquals(Short.valueOf((short) 300), options.getShort()); + assertEquals(Integer.valueOf(100000), options.getInt()); + assertEquals(Long.valueOf(123890123890L), options.getLong()); + assertEquals(Float.valueOf(55.5f), options.getFloat(), 0.0f); + assertEquals(Double.valueOf(12.3), options.getDouble(), 0.0); + assertEquals("stringValue", options.getString()); + assertEquals(PipelineOptionsFactoryTest.class, options.getClassValue()); + } + + @Test + public void testMissingArgument() { + String[] args = new String[] {}; + + Objects options = PipelineOptionsFactory.fromArgs(args).as(Objects.class); + assertNull(options.getString()); + } + + /** A test interface containing all supported array return types */ + public static interface Arrays extends PipelineOptions { + boolean[] getBoolean(); + void setBoolean(boolean[] value); + char[] getChar(); + void setChar(char[] value); + short[] getShort(); + void setShort(short[] value); + int[] getInt(); + void setInt(int[] value); + long[] getLong(); + void setLong(long[] value); + float[] getFloat(); + void setFloat(float[] value); + double[] getDouble(); + void setDouble(double[] value); + String[] getString(); + void setString(String[] value); + Class[] getClassValue(); + void setClassValue(Class[] value); + } + + @Test + public void testArrays() { + String[] args = new String[] { + "--boolean=true", + "--boolean=true", + "--boolean=false", + "--char=d", + "--char=e", + "--char=f", + "--short=300", + "--short=301", + "--short=302", + "--int=100000", + "--int=100001", + "--int=100002", + "--long=123890123890", + "--long=123890123891", + "--long=123890123892", + "--float=55.5", + "--float=55.6", + "--float=55.7", + "--double=12.3", + "--double=12.4", + "--double=12.5", + "--string=stringValue1", + "--string=stringValue2", + "--string=stringValue3", + "--classValue=" + PipelineOptionsFactory.class.getName(), + "--classValue=" + PipelineOptionsFactoryTest.class.getName()}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + boolean[] bools = options.getBoolean(); + assertTrue(bools[0] && bools[1] && !bools[2]); + assertArrayEquals(new char[] {'d', 'e', 'f'}, options.getChar()); + assertArrayEquals(new short[] {300, 301, 302}, options.getShort()); + assertArrayEquals(new int[] {100000, 100001, 100002}, options.getInt()); + assertArrayEquals(new long[] {123890123890L, 123890123891L, 123890123892L}, options.getLong()); + assertArrayEquals(new float[] {55.5f, 55.6f, 55.7f}, options.getFloat(), 0.0f); + assertArrayEquals(new double[] {12.3, 12.4, 12.5}, options.getDouble(), 0.0); + assertArrayEquals(new String[] {"stringValue1", "stringValue2", "stringValue3"}, + options.getString()); + assertArrayEquals(new Class[] {PipelineOptionsFactory.class, + PipelineOptionsFactoryTest.class}, + options.getClassValue()); + } + + @Test + public void testOutOfOrderArrays() { + String[] args = new String[] { + "--char=d", + "--boolean=true", + "--boolean=true", + "--char=e", + "--char=f", + "--boolean=false"}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + boolean[] bools = options.getBoolean(); + assertTrue(bools[0] && bools[1] && !bools[2]); + assertArrayEquals(new char[] {'d', 'e', 'f'}, options.getChar()); + } + + /** A test interface containing all supported List return types */ + public static interface Lists extends PipelineOptions { + List getString(); + void setString(List value); + } + + @Test + public void testList() { + String[] args = + new String[] {"--string=stringValue1", "--string=stringValue2", "--string=stringValue3"}; + + Lists options = PipelineOptionsFactory.fromArgs(args).as(Lists.class); + assertEquals(ImmutableList.of("stringValue1", "stringValue2", "stringValue3"), + options.getString()); + } + + @Test + public void testListShorthand() { + String[] args = new String[] {"--string=stringValue1,stringValue2,stringValue3"}; + + Lists options = PipelineOptionsFactory.fromArgs(args).as(Lists.class); + assertEquals(ImmutableList.of("stringValue1", "stringValue2", "stringValue3"), + options.getString()); + } + + @Test + public void testMixedShorthandAndLongStyleList() { + String[] args = new String[] { + "--char=d", + "--char=e", + "--char=f", + "--char=g,h,i", + "--char=j", + "--char=k", + "--char=l", + "--char=m,n,o"}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + assertArrayEquals(new char[] {'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o'}, + options.getChar()); + } + + @Test + public void testSetASingularAttributeUsingAListThrowsAnError() { + String[] args = new String[] { + "--diskSizeGb=100", + "--diskSizeGb=200"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("expected one element but was"); + PipelineOptionsFactory.fromArgs(args).create(); + } + + @Test + public void testSettingRunner() { + String[] args = new String[] {"--runner=BlockingDataflowPipelineRunner"}; + + PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(BlockingDataflowPipelineRunner.class, options.getRunner()); + } + + @Test + public void testSettingUnknownRunner() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Unknown 'runner' specified UnknownRunner, supported pipeline " + + "runners [DirectPipelineRunner, DataflowPipelineRunner, BlockingDataflowPipelineRunner]"); + String[] args = new String[] {"--runner=UnknownRunner"}; + + PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create(); + options.getRunner(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsTest.java new file mode 100644 index 000000000000..9db6a6b75422 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsTest.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link PipelineOptions}. */ +@RunWith(JUnit4.class) +public class PipelineOptionsTest { + /** Interface used for testing that {@link PipelineOptions#as(Class)} functions */ + public static interface TestOptions extends PipelineOptions { + } + + @Test + public void testDynamicAs() { + TestOptions options = PipelineOptionsFactory.create().as(TestOptions.class); + assertNotNull(options); + } + + @Test + public void testDefaultRunnerIsSet() { + assertEquals(DirectPipelineRunner.class, PipelineOptionsFactory.create().getRunner()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidatorTest.java new file mode 100644 index 000000000000..e0decb9f9225 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidatorTest.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PipelineOptionsValidator}. */ +@RunWith(JUnit4.class) +public class PipelineOptionsValidatorTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + /** A test interface with an {@link Validation.Required} annotation. */ + public static interface Required extends PipelineOptions { + @Validation.Required + public String getObject(); + public void setObject(String value); + } + + @Test + public void testWhenRequiredOptionIsSet() { + Required required = PipelineOptionsFactory.as(Required.class); + required.setObject("blah"); + PipelineOptionsValidator.validate(Required.class, required); + } + + @Test + public void testWhenRequiredOptionIsSetAndCleared() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Expected non-null property to be set for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject()]."); + + Required required = PipelineOptionsFactory.as(Required.class); + required.setObject("blah"); + required.setObject(null); + PipelineOptionsValidator.validate(Required.class, required); + } + + @Test + public void testWhenRequiredOptionIsNeverSet() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Expected non-null property to be set for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject()]."); + + Required required = PipelineOptionsFactory.as(Required.class); + PipelineOptionsValidator.validate(Required.class, required); + } + + /** A test interface which overrides the parents method. */ + public static interface SubClassValidation extends Required { + @Override + public String getObject(); + @Override + public void setObject(String value); + } + + @Test + public void testValidationOnOverriddenMethods() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Expected non-null property to be set for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject()]."); + + SubClassValidation required = PipelineOptionsFactory.as(SubClassValidation.class); + PipelineOptionsValidator.validate(Required.class, required); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandlerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandlerTest.java new file mode 100644 index 000000000000..b9b07e8626b1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandlerTest.java @@ -0,0 +1,625 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** Tests for {@link ProxyInvocationHandler}. */ +@RunWith(JUnit4.class) +public class ProxyInvocationHandlerTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + /** A test interface with some primitives and objects. */ + public static interface Simple extends PipelineOptions { + boolean isOptionEnabled(); + void setOptionEnabled(boolean value); + int getPrimitive(); + void setPrimitive(int value); + String getString(); + void setString(String value); + } + + @Test + public void testPropertySettingAndGetting() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + proxy.setString("OBJECT"); + proxy.setOptionEnabled(true); + proxy.setPrimitive(4); + assertEquals("OBJECT", proxy.getString()); + assertTrue(proxy.isOptionEnabled()); + assertEquals(4, proxy.getPrimitive()); + } + + /** A test interface containing all the JLS default values. */ + public static interface JLSDefaults extends PipelineOptions { + boolean getBoolean(); + void setBoolean(boolean value); + char getChar(); + void setChar(char value); + byte getByte(); + void setByte(byte value); + short getShort(); + void setShort(short value); + int getInt(); + void setInt(int value); + long getLong(); + void setLong(long value); + float getFloat(); + void setFloat(float value); + double getDouble(); + void setDouble(double value); + Object getObject(); + void setObject(Object value); + } + + @Test + public void testGettingJLSDefaults() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + JLSDefaults proxy = handler.as(JLSDefaults.class); + assertFalse(proxy.getBoolean()); + assertEquals('\0', proxy.getChar()); + assertEquals((byte) 0, proxy.getByte()); + assertEquals((short) 0, proxy.getShort()); + assertEquals(0, proxy.getInt()); + assertEquals(0L, proxy.getLong()); + assertEquals(0f, proxy.getFloat(), 0f); + assertEquals(0d, proxy.getDouble(), 0d); + assertNull(proxy.getObject()); + } + + /** A {@link DefaultValueFactory} which is used for testing. */ + public static class TestOptionFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + return "testOptionFactory"; + } + } + + /** A test interface containing all the {@link Default} annotations. */ + public static interface DefaultAnnotations extends PipelineOptions { + @Default.Boolean(true) + boolean getBoolean(); + void setBoolean(boolean value); + @Default.Character('a') + char getChar(); + void setChar(char value); + @Default.Byte((byte) 4) + byte getByte(); + void setByte(byte value); + @Default.Short((short) 5) + short getShort(); + void setShort(short value); + @Default.Integer(6) + int getInt(); + void setInt(int value); + @Default.Long(7L) + long getLong(); + void setLong(long value); + @Default.Float(8f) + float getFloat(); + void setFloat(float value); + @Default.Double(9d) + double getDouble(); + void setDouble(double value); + @Default.String("testString") + String getString(); + void setString(String value); + @Default.Class(DefaultAnnotations.class) + Class getClassOption(); + void setClassOption(Class value); + @Default.InstanceFactory(TestOptionFactory.class) + String getComplex(); + void setComplex(String value); + } + + @Test + public void testAnnotationDefaults() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + DefaultAnnotations proxy = handler.as(DefaultAnnotations.class); + assertTrue(proxy.getBoolean()); + assertEquals('a', proxy.getChar()); + assertEquals((byte) 4, proxy.getByte()); + assertEquals((short) 5, proxy.getShort()); + assertEquals(6, proxy.getInt()); + assertEquals(7, proxy.getLong()); + assertEquals(8f, proxy.getFloat(), 0f); + assertEquals(9d, proxy.getDouble(), 0d); + assertEquals("testString", proxy.getString()); + assertEquals(DefaultAnnotations.class, proxy.getClassOption()); + assertEquals("testOptionFactory", proxy.getComplex()); + } + + @Test + public void testEquals() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + JLSDefaults sameAsProxy = proxy.as(JLSDefaults.class); + ProxyInvocationHandler handler2 = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy2 = handler2.as(Simple.class); + JLSDefaults sameAsProxy2 = proxy2.as(JLSDefaults.class); + assertTrue(handler.equals(proxy)); + assertTrue(proxy.equals(proxy)); + assertTrue(proxy.equals(sameAsProxy)); + assertFalse(handler.equals(handler2)); + assertFalse(proxy.equals(proxy2)); + assertFalse(proxy.equals(sameAsProxy2)); + } + + @Test + public void testHashCode() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + JLSDefaults sameAsProxy = proxy.as(JLSDefaults.class); + ProxyInvocationHandler handler2 = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy2 = handler.as(Simple.class); + JLSDefaults sameAsProxy2 = proxy.as(JLSDefaults.class); + assertTrue(handler.hashCode() == proxy.hashCode()); + assertTrue(proxy.hashCode() == sameAsProxy.hashCode()); + assertFalse(handler.hashCode() != handler2.hashCode()); + assertFalse(proxy.hashCode() != proxy2.hashCode()); + assertFalse(proxy.hashCode() != sameAsProxy2.hashCode()); + } + + @Test + public void testToString() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + proxy.setString("stringValue"); + DefaultAnnotations proxy2 = proxy.as(DefaultAnnotations.class); + proxy2.setLong(57L); + assertEquals("Current Settings:\n" + + " long: 57\n" + + " string: stringValue\n", + proxy.toString()); + } + + /** A test interface containing an unknown method. */ + public static interface UnknownMethod { + void unknownMethod(); + } + + @Test + public void testInvokeWithUnknownMethod() throws Exception { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Unknown method [public abstract void com.google.cloud." + + "dataflow.sdk.options.ProxyInvocationHandlerTest$UnknownMethod.unknownMethod()] " + + "invoked with args [null]."); + + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + handler.invoke(handler, UnknownMethod.class.getMethod("unknownMethod"), null); + } + + /** A test interface which extends another interface. */ + public static interface SubClass extends Simple { + String getExtended(); + void setExtended(String value); + } + + @Test + public void testSubClassStoresSuperInterfaceValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setString("parentValue"); + assertEquals("parentValue", extended.getString()); + } + + @Test + public void testUpCastRetainsSuperInterfaceValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setString("parentValue"); + Simple simple = extended.as(Simple.class); + assertEquals("parentValue", simple.getString()); + } + + @Test + public void testUpCastRetainsSubClassValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setExtended("subClassValue"); + SubClass extended2 = extended.as(Simple.class).as(SubClass.class); + assertEquals("subClassValue", extended2.getExtended()); + } + + /** A test interface which is a sibling to {@link SubClass}. */ + public static interface Sibling extends Simple { + String getSibling(); + void setSibling(String value); + } + + @Test + public void testAsSiblingRetainsSuperInterfaceValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setString("parentValue"); + Sibling sibling = extended.as(Sibling.class); + assertEquals("parentValue", sibling.getString()); + } + + /** A test interface which has the same methods as the parent. */ + public static interface MethodConflict extends Simple { + @Override + String getString(); + @Override + void setString(String value); + } + + @Test + public void testMethodConflictProvidesSameValue() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + MethodConflict methodConflict = handler.as(MethodConflict.class); + + methodConflict.setString("conflictValue"); + assertEquals("conflictValue", methodConflict.getString()); + assertEquals("conflictValue", methodConflict.as(Simple.class).getString()); + } + + /** A test interface which has the same methods as its parent and grandparent. */ + public static interface DeepMethodConflict extends MethodConflict { + @Override + String getString(); + @Override + void setString(String value); + @Override + int getPrimitive(); + @Override + void setPrimitive(int value); + } + + @Test + public void testDeepMethodConflictProvidesSameValue() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + DeepMethodConflict deepMethodConflict = handler.as(DeepMethodConflict.class); + + // Tests overriding an already overridden method + deepMethodConflict.setString("conflictValue"); + assertEquals("conflictValue", deepMethodConflict.getString()); + assertEquals("conflictValue", deepMethodConflict.as(MethodConflict.class).getString()); + assertEquals("conflictValue", deepMethodConflict.as(Simple.class).getString()); + + // Tests overriding a method from an ancestor class + deepMethodConflict.setPrimitive(5); + assertEquals(5, deepMethodConflict.getPrimitive()); + assertEquals(5, deepMethodConflict.as(MethodConflict.class).getPrimitive()); + assertEquals(5, deepMethodConflict.as(Simple.class).getPrimitive()); + } + + /** A test interface which shares the same methods as {@link Sibling}. */ + public static interface SimpleSibling extends PipelineOptions { + String getString(); + void setString(String value); + } + + @Test + public void testDisjointSiblingsShareValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SimpleSibling proxy = handler.as(SimpleSibling.class); + proxy.setString("siblingValue"); + assertEquals("siblingValue", proxy.getString()); + assertEquals("siblingValue", proxy.as(Simple.class).getString()); + } + + /** A test interface which joins two sibling interfaces which have conflicting methods. */ + public static interface SiblingMethodConflict extends Simple, SimpleSibling { + } + + @Test + public void testSiblingMethodConflict() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SiblingMethodConflict siblingMethodConflict = handler.as(SiblingMethodConflict.class); + siblingMethodConflict.setString("siblingValue"); + assertEquals("siblingValue", siblingMethodConflict.getString()); + assertEquals("siblingValue", siblingMethodConflict.as(Simple.class).getString()); + assertEquals("siblingValue", siblingMethodConflict.as(SimpleSibling.class).getString()); + } + + /** A test interface which has only the getter and only a setter overriden. */ + public static interface PartialMethodConflict extends Simple { + @Override + String getString(); + @Override + void setPrimitive(int value); + } + + @Test + public void testPartialMethodConflictProvidesSameValue() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + PartialMethodConflict partialMethodConflict = handler.as(PartialMethodConflict.class); + + // Tests overriding a getter property which is only partially bound + partialMethodConflict.setString("conflictValue"); + assertEquals("conflictValue", partialMethodConflict.getString()); + assertEquals("conflictValue", partialMethodConflict.as(Simple.class).getString()); + + // Tests overriding a setter property which is only partially bound + partialMethodConflict.setPrimitive(5); + assertEquals(5, partialMethodConflict.getPrimitive()); + assertEquals(5, partialMethodConflict.as(Simple.class).getPrimitive()); + } + + @Test + public void testJsonConversionForDefault() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + assertNotNull(serializeDeserialize(PipelineOptions.class, options)); + } + + /** Test interface for JSON conversion of simple types */ + private static interface SimpleTypes extends PipelineOptions { + int getInteger(); + void setInteger(int value); + String getString(); + void setString(String value); + } + + @Test + public void testJsonConversionForSimpleTypes() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setString("TestValue"); + options.setInteger(5); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, options); + assertEquals(5, options2.getInteger()); + assertEquals("TestValue", options2.getString()); + } + + @Test + public void testJsonConversionOfAJsonConvertedType() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setString("TestValue"); + options.setInteger(5); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, + serializeDeserialize(SimpleTypes.class, options)); + assertEquals(5, options2.getInteger()); + assertEquals("TestValue", options2.getString()); + } + + @Test + public void testJsonConversionForPartiallySerializedValues() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setInteger(5); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, options); + options2.setString("TestValue"); + SimpleTypes options3 = serializeDeserialize(SimpleTypes.class, options2); + assertEquals(5, options3.getInteger()); + assertEquals("TestValue", options3.getString()); + } + + @Test + public void testJsonConversionForOverriddenSerializedValues() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setInteger(-5); + options.setString("NeedsToBeOverridden"); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, options); + options2.setInteger(5); + options2.setString("TestValue"); + SimpleTypes options3 = serializeDeserialize(SimpleTypes.class, options2); + assertEquals(5, options3.getInteger()); + assertEquals("TestValue", options3.getString()); + } + + /** Test interface for JSON conversion of container types */ + private static interface ContainerTypes extends PipelineOptions { + List getList(); + void setList(List values); + Map getMap(); + void setMap(Map values); + Set getSet(); + void setSet(Set values); + } + + @Test + public void testJsonConversionForContainerTypes() throws Exception { + List list = ImmutableList.of("a", "b", "c"); + Map map = ImmutableMap.of("d", "x", "e", "y", "f", "z"); + Set set = ImmutableSet.of("g", "h", "i"); + ContainerTypes options = PipelineOptionsFactory.as(ContainerTypes.class); + options.setList(list); + options.setMap(map); + options.setSet(set); + ContainerTypes options2 = serializeDeserialize(ContainerTypes.class, options); + assertEquals(list, options2.getList()); + assertEquals(map, options2.getMap()); + assertEquals(set, options2.getSet()); + } + + /** Test interface for conversion of inner types */ + private static class InnerType { + public double doubleField; + + static InnerType of(double value) { + InnerType rval = new InnerType(); + rval.doubleField = value; + return rval; + } + + @Override + public boolean equals(Object obj) { + return obj != null + && getClass().equals(obj.getClass()) + && Objects.equals(doubleField, ((InnerType) obj).doubleField); + } + } + + /** Test interface for conversion of generics and inner types */ + private static class ComplexType { + public String stringField; + public Integer intField; + public List genericType; + public InnerType innerType; + + @Override + public boolean equals(Object obj) { + return obj != null + && getClass().equals(obj.getClass()) + && Objects.equals(stringField, ((ComplexType) obj).stringField) + && Objects.equals(intField, ((ComplexType) obj).intField) + && Objects.equals(genericType, ((ComplexType) obj).genericType) + && Objects.equals(innerType, ((ComplexType) obj).innerType); + } + } + + private static interface ComplexTypes extends PipelineOptions { + ComplexType getComplexType(); + void setComplexType(ComplexType value); + } + + @Test + public void testJsonConversionForComplexType() throws Exception { + ComplexType complexType = new ComplexType(); + complexType.stringField = "stringField"; + complexType.intField = 12; + complexType.innerType = InnerType.of(12); + complexType.genericType = ImmutableList.of(InnerType.of(16234), InnerType.of(24)); + + ComplexTypes options = PipelineOptionsFactory.as(ComplexTypes.class); + options.setComplexType(complexType); + ComplexTypes options2 = serializeDeserialize(ComplexTypes.class, options); + assertEquals(complexType, options2.getComplexType()); + } + + /** Test interface for testing ignored properties during serialization. */ + private static interface IgnoredProperty extends PipelineOptions { + @JsonIgnore + String getValue(); + void setValue(String value); + } + + @Test + public void testJsonConversionOfIgnoredProperty() throws Exception { + IgnoredProperty options = PipelineOptionsFactory.as(IgnoredProperty.class); + options.setValue("TestValue"); + + IgnoredProperty options2 = serializeDeserialize(IgnoredProperty.class, options); + assertNull(options2.getValue()); + } + + /** Test class which is not serializable by Jackson. */ + public static class NotSerializable { + private String value; + public NotSerializable(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** Test interface containing a class which is not serializable by Jackson. */ + private static interface NotSerializableProperty extends PipelineOptions { + NotSerializable getValue(); + void setValue(NotSerializable value); + } + + @Test(expected = JsonMappingException.class) + public void testJsonConversionOfNotSerializableProperty() throws Exception { + NotSerializableProperty options = PipelineOptionsFactory.as(NotSerializableProperty.class); + options.setValue(new NotSerializable("TestString")); + + serializeDeserialize(NotSerializableProperty.class, options); + } + + /** + * Test interface which has {@link JsonIgnore @JsonIgnore} on a property that Jackson + * can't serialize. + */ + private static interface IgnoredNotSerializableProperty extends PipelineOptions { + @JsonIgnore + NotSerializable getValue(); + void setValue(NotSerializable value); + } + + @Test + public void testJsonConversionOfIgnoredNotSerializableProperty() throws Exception { + IgnoredNotSerializableProperty options = + PipelineOptionsFactory.as(IgnoredNotSerializableProperty.class); + options.setValue(new NotSerializable("TestString")); + + IgnoredNotSerializableProperty options2 = + serializeDeserialize(IgnoredNotSerializableProperty.class, options); + assertNull(options2.getValue()); + } + + /** Test class which is only serializable by Jackson with the added metadata. */ + public static class SerializableWithMetadata { + private String value; + public SerializableWithMetadata(@JsonProperty("value") String value) { + this.value = value; + } + + @JsonProperty("value") + public String getValue() { + return value; + } + } + + /** + * Test interface containing a property which is only serializable by Jackson with + * the additional metadata. + */ + private static interface SerializableWithMetadataProperty extends PipelineOptions { + SerializableWithMetadata getValue(); + void setValue(SerializableWithMetadata value); + } + + @Test + public void testJsonConversionOfSerializableWithMetadataProperty() throws Exception { + SerializableWithMetadataProperty options = + PipelineOptionsFactory.as(SerializableWithMetadataProperty.class); + options.setValue(new SerializableWithMetadata("TestString")); + + SerializableWithMetadataProperty options2 = + serializeDeserialize(SerializableWithMetadataProperty.class, options); + assertEquals("TestString", options2.getValue().getValue()); + } + + private T serializeDeserialize(Class kls, PipelineOptions options) + throws Exception { + ObjectMapper mapper = new ObjectMapper(); + String value = mapper.writeValueAsString(options); + return mapper.readValue(value, PipelineOptions.class).as(kls); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunnerTest.java new file mode 100644 index 000000000000..398326e8a385 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunnerTest.java @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil.JobState; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.util.Date; +import java.util.concurrent.TimeUnit; + +/** + * Tests for BlockingDataflowPipelineRunner. + */ +@RunWith(JUnit4.class) +public class BlockingDataflowPipelineRunnerTest { + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(BlockingDataflowPipelineRunner.class); + + // This class mocks a call to DataflowPipelineJob.waitToFinish(): + // it blocks the thread to simulate waiting, + // and releases the blocking once signaled + static class MockWaitToFinish implements Answer { + NotificationHelper jobCompleted = new NotificationHelper(); + + public Object answer(InvocationOnMock invocation) throws InterruptedException { + System.out.println("MockWaitToFinish.answer(): Wait for signaling job completion."); + assertTrue("Test did not receive mock job completion signal", + jobCompleted.waitTillSet(10000)); + + System.out.println("MockWaitToFinish.answer(): job completed."); + return JobState.DONE; + } + + public void signalJobComplete() { + jobCompleted.set(); + } + } + + // Mini helper class for wait-notify + static class NotificationHelper { + private boolean isSet = false; + + public synchronized void set() { + isSet = true; + notifyAll(); + } + + public synchronized boolean check() { + return isSet; + } + + public synchronized boolean waitTillSet(long timeout) throws InterruptedException { + long remainingTimeout = timeout; + long startTime = new Date().getTime(); + while (!isSet && remainingTimeout > 0) { + wait(remainingTimeout); + remainingTimeout = timeout - (new Date().getTime() - startTime); + } + + return isSet; + } + } + + @Test + public void testJobWaitComplete() throws IOException, InterruptedException { + expectedLogs.expectInfo("Job finished with status DONE"); + + DataflowPipelineRunner mockDataflowPipelineRunner = mock(DataflowPipelineRunner.class); + DataflowPipelineJob mockJob = mock(DataflowPipelineJob.class); + MockWaitToFinish mockWait = new MockWaitToFinish(); + + when(mockJob.waitToFinish( + anyLong(), isA(TimeUnit.class), isA(MonitoringUtil.JobMessagesHandler.class))) + .thenAnswer(mockWait); + when(mockDataflowPipelineRunner.run(isA(Pipeline.class))).thenReturn(mockJob); + + // Construct a BlockingDataflowPipelineRunner with mockDataflowPipelineRunner inside + final BlockingDataflowPipelineRunner blockingRunner = + new BlockingDataflowPipelineRunner( + mockDataflowPipelineRunner, + new MonitoringUtil.PrintHandler(System.out)); + + final NotificationHelper executionStarted = new NotificationHelper(); + final NotificationHelper jobCompleted = new NotificationHelper(); + + new Thread() { + public void run() { + executionStarted.set(); + + // Run on an empty test pipeline. + blockingRunner.run(DirectPipeline.createForTest()); + + // Test following code is not reached till mock job completion signal. + jobCompleted.set(); + } + }.start(); + + assertTrue("'executionStarted' event not set till timeout.", + executionStarted.waitTillSet(2000)); + assertFalse("Code after job completion should not be reached before mock signal.", + jobCompleted.check()); + + mockWait.signalJobComplete(); + assertTrue("run() should return after job completion is mocked.", + jobCompleted.waitTillSet(2000)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJobTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJobTest.java new file mode 100644 index 000000000000..30697deecc0f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJobTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.Job; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil.JobState; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** + * Tests for DataflowPipelineJob. + */ +@RunWith(JUnit4.class) +public class DataflowPipelineJobTest { + private static final String PROJECT_ID = "someProject"; + private static final String JOB_ID = "1234"; + + @Test + public void testWaitToFinish() throws IOException, InterruptedException { + Dataflow mockWorkflowClient = mock(Dataflow.class); + Dataflow.V1b3 mockV1b3 = mock(Dataflow.V1b3.class); + Dataflow.V1b3.Projects mockProjects = mock(Dataflow.V1b3.Projects.class); + Dataflow.V1b3.Projects.Jobs mockJobs = mock(Dataflow.V1b3.Projects.Jobs.class); + Dataflow.V1b3.Projects.Jobs.Get statusRequest = mock(Dataflow.V1b3.Projects.Jobs.Get.class); + + Job statusResponse = new Job(); + statusResponse.setCurrentState(JobState.DONE.getStateName()); + + when(mockWorkflowClient.v1b3()).thenReturn(mockV1b3); + when(mockV1b3.projects()).thenReturn(mockProjects); + when(mockProjects.jobs()).thenReturn(mockJobs); + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))) + .thenReturn(statusRequest); + when(statusRequest.execute()).thenReturn(statusResponse); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient); + + JobState state = job.waitToFinish(1, TimeUnit.MINUTES, null); + assertEquals(JobState.DONE, state); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java new file mode 100644 index 000000000000..7995445c9869 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java @@ -0,0 +1,501 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Job; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.DataflowReleaseInfo; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.PackageUtil; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.io.File; +import java.io.IOException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for DataflowPipelineRunner. + */ +@RunWith(JUnit4.class) +public class DataflowPipelineRunnerTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + // Asserts that the given Job has all expected fields set. + private static void assertValidJob(Job job) { + assertNull(job.getId()); + assertNull(job.getCurrentState()); + } + + private DataflowPipeline buildDataflowPipeline(DataflowPipelineOptions options) { + DataflowPipeline p = DataflowPipeline.create(options); + + p.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/object")) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/object")); + + return p; + } + + private static Dataflow buildMockDataflow( + final ArgumentCaptor jobCaptor) throws IOException { + Dataflow mockDataflowClient = mock(Dataflow.class); + Dataflow.V1b3 mockV1b3 = mock(Dataflow.V1b3.class); + Dataflow.V1b3.Projects mockProjects = mock(Dataflow.V1b3.Projects.class); + Dataflow.V1b3.Projects.Jobs mockJobs = mock(Dataflow.V1b3.Projects.Jobs.class); + Dataflow.V1b3.Projects.Jobs.Create mockRequest = + mock(Dataflow.V1b3.Projects.Jobs.Create.class); + + when(mockDataflowClient.v1b3()).thenReturn(mockV1b3); + when(mockV1b3.projects()).thenReturn(mockProjects); + when(mockProjects.jobs()).thenReturn(mockJobs); + when(mockJobs.create(eq("someProject"), jobCaptor.capture())) + .thenReturn(mockRequest); + + Job resultJob = new Job(); + resultJob.setId("newid"); + when(mockRequest.execute()).thenReturn(resultJob); + return mockDataflowClient; + } + + private GcsUtil buildMockGcsUtil() throws IOException { + GcsUtil mockGcsUtil = mock(GcsUtil.class); + when(mockGcsUtil.create( + any(GcsPath.class), anyString())) + .thenReturn(FileChannel.open( + Files.createTempFile("channel-", ".tmp"), + StandardOpenOption.CREATE, StandardOpenOption.DELETE_ON_CLOSE)); + return mockGcsUtil; + } + + private DataflowPipelineOptions buildPipelineOptions( + ArgumentCaptor jobCaptor) throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setProject("someProject"); + options.setTempLocation(DataflowPipelineRunner.verifyGcsPath( + GcsPath.fromComponents("somebucket", "some/path")).toString()); + // Set FILES_PROPERTY to empty to prevent a default value calculated from classpath. + options.setFilesToStage(new LinkedList()); + options.setDataflowClient(buildMockDataflow(jobCaptor)); + options.setGcsUtil(buildMockGcsUtil()); + options.setGcpCredential(new TestCredential()); + return options; + } + + @Test + public void testRun() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + DataflowPipeline p = buildDataflowPipeline(options); + DataflowPipelineJob job = p.run(); + assertEquals("newid", job.getJobId()); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testRunWithFiles() throws IOException { + // Test that the function DataflowPipelineRunner.stageFiles works as + // expected. + GcsUtil mockGcsUtil = buildMockGcsUtil(); + final GcsPath gcsStaging = + GcsPath.fromComponents("somebucket", "some/path"); + final GcsPath gcsTemp = + GcsPath.fromComponents("somebucket", "some/temp/path"); + final String cloudDataflowDataset = "somedataset"; + + // Create some temporary files. + File temp1 = File.createTempFile("DataflowPipelineRunnerTest", "txt"); + temp1.deleteOnExit(); + File temp2 = File.createTempFile("DataflowPipelineRunnerTest2", "txt"); + temp2.deleteOnExit(); + + DataflowPackage expectedPackage1 = PackageUtil.createPackage( + temp1.getAbsolutePath(), gcsStaging, null); + + String overridePackageName = "alias.txt"; + DataflowPackage expectedPackage2 = PackageUtil.createPackage( + temp2.getAbsolutePath(), gcsStaging, overridePackageName); + + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setFilesToStage(ImmutableList.of( + temp1.getAbsolutePath(), + overridePackageName + "=" + temp2.getAbsolutePath())); + options.setStagingLocation(gcsStaging.toString()); + options.setTempLocation(gcsTemp.toString()); + options.setTempDatasetId(cloudDataflowDataset); + options.setProject("someProject"); + options.setJobName("job"); + options.setDataflowClient(buildMockDataflow(jobCaptor)); + options.setGcsUtil(mockGcsUtil); + options.setGcpCredential(new TestCredential()); + + DataflowPipeline p = buildDataflowPipeline(options); + + DataflowPipelineJob job = p.run(); + assertEquals("newid", job.getJobId()); + + Job workflowJob = jobCaptor.getValue(); + assertValidJob(workflowJob); + + assertEquals( + 2, + workflowJob.getEnvironment().getWorkerPools().get(0).getPackages().size()); + DataflowPackage workflowPackage1 = + workflowJob.getEnvironment().getWorkerPools().get(0).getPackages().get(0); + assertEquals(expectedPackage1.getName(), workflowPackage1.getName()); + assertEquals(expectedPackage1.getLocation(), workflowPackage1.getLocation()); + DataflowPackage workflowPackage2 = + workflowJob.getEnvironment().getWorkerPools().get(0).getPackages().get(1); + assertEquals(expectedPackage2.getName(), workflowPackage2.getName()); + assertEquals(expectedPackage2.getLocation(), workflowPackage2.getLocation()); + + assertEquals( + gcsTemp.toResourceName(), + workflowJob.getEnvironment().getTempStoragePrefix()); + assertEquals( + cloudDataflowDataset, + workflowJob.getEnvironment().getDataset()); + assertEquals( + DataflowReleaseInfo.getReleaseInfo().getName(), + workflowJob.getEnvironment().getUserAgent().get("name")); + assertEquals( + DataflowReleaseInfo.getReleaseInfo().getVersion(), + workflowJob.getEnvironment().getUserAgent().get("version")); + } + + @Test + public void runWithDefaultFilesToStage() throws Exception { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setFilesToStage(null); + DataflowPipelineRunner.fromOptions(options); + assertTrue(!options.getFilesToStage().isEmpty()); + } + + @Test + public void detectClassPathResourceWithFileResources() throws Exception { + String path = "/tmp/file"; + String path2 = "/tmp/file2"; + URLClassLoader classLoader = new URLClassLoader(new URL[]{ + new URL("file://" + path), + new URL("file://" + path2) + }); + + assertEquals(ImmutableList.of(path, path2), + DataflowPipelineRunner.detectClassPathResourcesToStage(classLoader)); + } + + @Test + public void detectClassPathResourcesWithUnsupportedClassLoader() { + ClassLoader mockClassLoader = Mockito.mock(ClassLoader.class); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unable to use ClassLoader to detect classpath elements."); + + DataflowPipelineRunner.detectClassPathResourcesToStage(mockClassLoader); + } + + @Test + public void detectClassPathResourceWithNonFileResources() throws Exception { + String url = "http://www.google.com/all-the-secrets.jar"; + URLClassLoader classLoader = new URLClassLoader(new URL[]{ + new URL(url) + }); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unable to convert url (" + url + ") to file."); + + DataflowPipelineRunner.detectClassPathResourcesToStage(classLoader); + } + + @Test + public void testGcsStagingLocationInitialization() { + // Test that the staging location is initialized correctly. + GcsPath gcsTemp = GcsPath.fromComponents("somebucket", + "some/temp/path"); + + // Set temp location (required), and check that staging location is set. + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setTempLocation(gcsTemp.toString()); + options.setProject("testProject"); + options.setGcpCredential(new TestCredential()); + DataflowPipelineRunner.fromOptions(options); + + assertNotNull(options.getStagingLocation()); + } + + @Test + public void testGcsRequiredTempLocation() { + // Error raised if temp location not set. + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setProject("someProject"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString("tempLocation")); + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testNonGcsFilePathInReadFailure() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); + p.apply(TextIO.Read.named("ReadMyNonGcsFile").from("/tmp/file")); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString("GCS URI")); + p.run(); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testNonGcsFilePathInWriteFailure() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); + p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")) + .apply(TextIO.Write.named("WriteMyNonGcsFile").to("/tmp/file")); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString("GCS URI")); + p.run(); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testMultiSlashGcsFileReadPath() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); + p.apply(TextIO.Read.named("ReadInvalidGcsFile") + .from("gs://bucket/tmp//file")); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("consecutive slashes"); + p.run(); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testMultiSlashGcsFileWritePath() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); + p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")) + .apply(TextIO.Write.named("WriteInvalidGcsFile") + .to("gs://bucket/tmp//file")); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("consecutive slashes"); + p.run(); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testInvalidTempLocation() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setTempLocation("file://temp/location"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString("GCS URI")); + DataflowPipelineRunner.fromOptions(options); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testInvalidStagingLocation() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setStagingLocation("file://my/staging/location"); + try { + DataflowPipelineRunner.fromOptions(options); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), containsString("GCS URI")); + } + options.setStagingLocation("my/staging/location"); + try { + DataflowPipelineRunner.fromOptions(options); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), containsString("GCS URI")); + } + } + + @Test + public void testInvalidJobName() throws IOException { + List invalidNames = Arrays.asList( + "invalid_name", + "0invalid", + "invalid-", + "this-one-is-too-long-01234567890123456789"); + List expectedReason = Arrays.asList( + "JobName invalid", + "JobName invalid", + "JobName invalid", + "JobName too long"); + + for (int i = 0; i < invalidNames.size(); ++i) { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setJobName(invalidNames.get(i)); + + try { + DataflowPipelineRunner.fromOptions(options); + fail("Expected IllegalArgumentException for jobName " + + options.getJobName()); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), + containsString(expectedReason.get(i))); + } + } + } + + @Test + public void testValidJobName() throws IOException { + List names = Arrays.asList("ok", "Ok", "A-Ok", "ok-123"); + + for (String name : names) { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setJobName(name); + + DataflowPipelineRunner runner = DataflowPipelineRunner + .fromOptions(options); + assertNotNull(runner); + } + } + + /** + * A fake PTransform for testing. + */ + public static class TestTransform + extends PTransform, PCollection> { + public boolean translated = false; + + @Override + public PCollection apply(PCollection input) { + return PCollection.createPrimitiveOutputInternal(new GlobalWindow()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return getInput().getCoder(); + } + } + + @Test + public void testTransformTranslatorMissing() throws IOException { + // Test that we throw if we don't provide a translation. + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + Pipeline p = DataflowPipeline.create(options); + + p.apply(Create.of(Arrays.asList(1, 2, 3))) + .apply(new TestTransform()); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage(Matchers.containsString("no translator registered")); + DataflowPipelineTranslator.fromOptions(options) + .translate(p, Collections.emptyList()); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testTransformTranslator() throws IOException { + // Test that we can provide a custom translation + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + DataflowPipeline p = DataflowPipeline.create(options); + TestTransform transform = new TestTransform(); + + p.apply(Create.of(Arrays.asList(1, 2, 3))) + .apply(transform) + .setCoder(BigEndianIntegerCoder.of()); + + DataflowPipelineTranslator translator = DataflowPipelineRunner + .fromOptions(options).getTranslator(); + + translator.registerTransformTranslator( + TestTransform.class, + new DataflowPipelineTranslator.TransformTranslator() { + @SuppressWarnings("unchecked") + @Override + public void translate( + TestTransform transform, + DataflowPipelineTranslator.TranslationContext context) { + transform.translated = true; + + // Note: This is about the minimum needed to fake out a + // translation. This obviously isn't a real translation. + context.addStep(transform, "TestTranslate"); + context.addOutput("output", transform.getOutput()); + } + }); + + translator.translate(p, Collections.emptyList()); + assertTrue(transform.translated); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java new file mode 100644 index 000000000000..e2edb9fdc223 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java @@ -0,0 +1,582 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.Structs.addObject; +import static com.google.cloud.dataflow.sdk.util.Structs.getDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.Step; +import com.google.api.services.dataflow.model.WorkerPool; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.OutputReference; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Iterables; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatcher; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * Tests for DataflowPipelineTranslator. + */ +@RunWith(JUnit4.class) +public class DataflowPipelineTranslatorTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + // A Custom Mockito matcher for an initial Job which checks that all + // expected fields are set. + private static class IsValidCreateRequest extends ArgumentMatcher { + public boolean matches(Object o) { + Job job = (Job) o; + return job.getId() == null + && job.getProjectId() == null + && job.getName() != null + && job.getType() != null + && job.getEnvironment() != null + && job.getSteps() != null + && job.getCurrentState() == null + && job.getCurrentStateTime() == null + && job.getExecutionInfo() == null + && job.getCreateTime() == null; + } + } + + private DataflowPipeline buildPipeline(DataflowPipelineOptions options) + throws IOException { + DataflowPipeline p = DataflowPipeline.create(options); + + p.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/object")) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/object")); + + return p; + } + + private static Dataflow buildMockDataflow( + ArgumentMatcher jobMatcher) throws IOException { + Dataflow mockDataflowClient = mock(Dataflow.class); + Dataflow.V1b3 mockV1b3 = mock(Dataflow.V1b3.class); + Dataflow.V1b3.Projects mockProjects = mock(Dataflow.V1b3.Projects.class); + Dataflow.V1b3.Projects.Jobs mockJobs = mock(Dataflow.V1b3.Projects.Jobs.class); + Dataflow.V1b3.Projects.Jobs.Create mockRequest = mock( + Dataflow.V1b3.Projects.Jobs.Create.class); + + when(mockDataflowClient.v1b3()).thenReturn(mockV1b3); + when(mockV1b3.projects()).thenReturn(mockProjects); + when(mockProjects.jobs()).thenReturn(mockJobs); + when(mockJobs.create(eq("someProject"), argThat(jobMatcher))) + .thenReturn(mockRequest); + + Job resultJob = new Job(); + resultJob.setId("newid"); + when(mockRequest.execute()).thenReturn(resultJob); + return mockDataflowClient; + } + + private static DataflowPipelineOptions buildPipelineOptions() throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + options.setProject("some-project"); + options.setTempLocation(GcsPath.fromComponents("somebucket", "some/path").toString()); + options.setFilesToStage(new LinkedList()); + options.setDataflowClient(buildMockDataflow(new IsValidCreateRequest())); + return options; + } + + @Test + public void testZoneConfig() throws IOException { + final String testZone = "test-zone-1"; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setZone(testZone); + + Pipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = DataflowPipelineTranslator.fromOptions(options).translate( + p, Collections.emptyList()); + + assertEquals(2, job.getEnvironment().getWorkerPools().size()); + assertEquals(testZone, + job.getEnvironment().getWorkerPools().get(0).getZone()); + assertEquals(testZone, + job.getEnvironment().getWorkerPools().get(1).getZone()); + } + + @Test + public void testWorkerMachineTypeConfig() throws IOException { + final String testMachineType = "test-machine-type"; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setWorkerMachineType(testMachineType); + + Pipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = DataflowPipelineTranslator.fromOptions(options).translate( + p, Collections.emptyList()); + + assertEquals(2, job.getEnvironment().getWorkerPools().size()); + + WorkerPool workerPool = null; + + if (job + .getEnvironment() + .getWorkerPools() + .get(0) + .getKind() + .equals(DataflowPipelineTranslator.HARNESS_WORKER_POOL)) { + workerPool = job.getEnvironment().getWorkerPools().get(0); + } else if (job + .getEnvironment() + .getWorkerPools() + .get(1) + .getKind() + .equals(DataflowPipelineTranslator.HARNESS_WORKER_POOL)) { + workerPool = job.getEnvironment().getWorkerPools().get(1); + } else { + fail("Missing worker pool."); + } + assertEquals(testMachineType, workerPool.getMachineType()); + } + + @Test + public void testDiskSizeGbConfig() throws IOException { + final Integer diskSizeGb = 1234; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setDiskSizeGb(diskSizeGb); + + Pipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = DataflowPipelineTranslator.fromOptions(options).translate( + p, Collections.emptyList()); + + assertEquals(2, job.getEnvironment().getWorkerPools().size()); + assertEquals(diskSizeGb, + job.getEnvironment().getWorkerPools().get(0).getDiskSizeGb()); + assertEquals(diskSizeGb, + job.getEnvironment().getWorkerPools().get(1).getDiskSizeGb()); + } + + @Test + public void testShufflePoolConfig() throws IOException { + final Integer numWorkers = 10; + final String diskSource = "test-disk-source"; + final Integer diskSizeGb = 12345; + final String zone = "test-zone-1"; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setShuffleNumWorkers(numWorkers); + options.setShuffleDiskSourceImage(diskSource); + options.setShuffleDiskSizeGb(diskSizeGb); + options.setShuffleZone(zone); + + Pipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = DataflowPipelineTranslator.fromOptions(options).translate( + p, Collections.emptyList()); + + assertEquals(2, job.getEnvironment().getWorkerPools().size()); + WorkerPool shufflePool = + job.getEnvironment().getWorkerPools().get(1); + assertEquals(shufflePool.getKind(), + DataflowPipelineTranslator.SHUFFLE_WORKER_POOL); + assertEquals(numWorkers, shufflePool.getNumWorkers()); + assertEquals(diskSource, shufflePool.getDiskSourceImage()); + assertEquals(diskSizeGb, shufflePool.getDiskSizeGb()); + assertEquals(zone, shufflePool.getZone()); + } + + @Test + public void testPredefinedAddStep() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + DataflowPipelineTranslator.registerTransformTranslator( + EmbeddedTransform.class, new EmbeddedTranslator()); + + // Create a predefined step using another pipeline + Step predefinedStep = createPredefinedStep(); + + // Create a pipeline that the predefined step will be embedded into + DataflowPipeline pipeline = DataflowPipeline.create(options); + pipeline.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/in")) + .apply(ParDo.of(new NoOpFn())) + .apply(new EmbeddedTransform(predefinedStep.clone())) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/out")); + Job job = translator.translate(pipeline, Collections.emptyList()); + + List steps = job.getSteps(); + assertEquals(4, steps.size()); + + // The input to the embedded step should match the output of the step before + Map step1Out = getOutputPortReference(steps.get(1)); + Map step2In = getDictionary( + steps.get(2).getProperties(), PropertyNames.PARALLEL_INPUT); + assertEquals(step1Out, step2In); + + // The output from the embedded step should match the input of the step after + Map step2Out = getOutputPortReference(steps.get(2)); + Map step3In = getDictionary( + steps.get(3).getProperties(), PropertyNames.PARALLEL_INPUT); + assertEquals(step2Out, step3In); + + // The step should not have been modified other than remapping the input + Step predefinedStepClone = predefinedStep.clone(); + Step embeddedStepClone = steps.get(2).clone(); + predefinedStepClone.getProperties().remove(PropertyNames.PARALLEL_INPUT); + embeddedStepClone.getProperties().remove(PropertyNames.PARALLEL_INPUT); + assertEquals(predefinedStepClone, embeddedStepClone); + } + + /** + * Construct a OutputReference for the output of the step. + */ + private static OutputReference getOutputPortReference(Step step) throws Exception { + // TODO: This should be done via a Structs accessor. + List> output = + (List>) step.getProperties().get(PropertyNames.OUTPUT_INFO); + String outputTagId = getString(Iterables.getOnlyElement(output), PropertyNames.OUTPUT_NAME); + return new OutputReference(step.getName(), outputTagId); + } + + /** + * Returns a Step for a DoFn by creating and translating a pipeline. + */ + private static Step createPredefinedStep() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + DataflowPipeline pipeline = DataflowPipeline.create(options); + String stepName = "DoFn1"; + pipeline.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/in")) + .apply(ParDo.of(new NoOpFn()).named(stepName)) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/out")); + Job job = translator.translate(pipeline, Collections.emptyList()); + + assertEquals(3, job.getSteps().size()); + Step step = job.getSteps().get(1); + assertEquals(stepName, getString(step.getProperties(), PropertyNames.USER_NAME)); + return step; + } + + private static class NoOpFn extends DoFn{ + @Override public void processElement(ProcessContext c) throws Exception { + c.output(c.element()); + } + } + + /** + * A placeholder transform that will be used to substitute a predefined Step. + */ + private static class EmbeddedTransform + extends PTransform, PCollection> { + private final Step step; + + public EmbeddedTransform(Step step) { + this.step = step; + } + + @Override + public PCollection apply(PCollection input) { + return PCollection.createPrimitiveOutputInternal(new GlobalWindow()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return StringUtf8Coder.of(); + } + } + + /** + * A TransformTranslator that adds the predefined Step using + * {@link TranslationContext#addStep} and remaps the input port reference. + */ + private static class EmbeddedTranslator + implements DataflowPipelineTranslator.TransformTranslator { + @Override public void translate(EmbeddedTransform transform, TranslationContext context) { + addObject(transform.step.getProperties(), PropertyNames.PARALLEL_INPUT, + context.asOutputReference(transform.getInput())); + context.addStep(transform, transform.step); + } + } + + /** + * A composite transform which returns an output that is unrelated to + * the input. + */ + private static class UnrelatedOutputCreator + extends PTransform, PCollection> { + + @Override + public PCollection apply(PCollection input) { + // Apply an operation so that this is a composite transform. + input.apply(Count.perElement()); + + // Return a value unrelated to the input. + return input.getPipeline().apply(Create.of(1, 2, 3, 4)); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VarIntCoder.of(); + } + } + + /** + * A composite transform which returns an output which is unbound. + */ + private static class UnboundOutputCreator + extends PTransform, PDone> { + + @Override + public PDone apply(PCollection input) { + // Apply an operation so that this is a composite transform. + input.apply(Count.perElement()); + + return new PDone(); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + } + + /** + * A composite transform which returns a partially bound output. + * + *

This is not allowed and will result in a failure. + */ + private static class PartiallyBoundOutputCreator + extends PTransform, PCollectionTuple> { + + public final TupleTag sumTag = new TupleTag<>("sum"); + public final TupleTag doneTag = new TupleTag<>("done"); + + @Override + public PCollectionTuple apply(PCollection input) { + PCollection sum = input.apply(Sum.integersGlobally()); + + // Fails here when attempting to construct a tuple with an unbound object. + return PCollectionTuple.of(sumTag, sum) + .and(doneTag, PCollection.createPrimitiveOutputInternal( + new GlobalWindow())); + } + } + + @Test + public void testMultiGraphPipelineSerialization() throws IOException { + Pipeline p = DataflowPipeline.create(buildPipelineOptions()); + + PCollection input = p.begin() + .apply(Create.of(1, 2, 3)); + + input.apply(new UnrelatedOutputCreator()); + input.apply(new UnboundOutputCreator()); + + DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions( + PipelineOptionsFactory.as(DataflowPipelineOptions.class)); + + // Check that translation doesn't fail. + t.translate(p, Collections.emptyList()); + } + + @Test + public void testPartiallyBoundFailure() throws IOException { + Pipeline p = DataflowPipeline.create(buildPipelineOptions()); + + PCollection input = p.begin() + .apply(Create.of(1, 2, 3)); + + thrown.expect(IllegalStateException.class); + input.apply(new PartiallyBoundOutputCreator()); + + Assert.fail("Failure expected from use of partially bound output"); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests a few corner cases that should not crash. + */ + @Test + public void testGoodWildcards() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + Pipeline pipeline = DataflowPipeline.create(options); + DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(options); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/*")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/?")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/[0-9]")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/*baz*")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/*baz?")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/[0-9]baz?")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/baz/*")); + pipeline.apply(TextIO.Read.from("gs://bucket/foo/baz/*wonka*")); + + // Check that translation doesn't fail. + t.translate(pipeline, Collections.emptyList()); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests "*". + */ + @Test + public void testBadWildcardStar() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + Pipeline pipeline = DataflowPipeline.create(options); + DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(options); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo*/baz")); + + // Check that translation does fail. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unsupported wildcard usage"); + t.translate(pipeline, Collections.emptyList()); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests "?". + */ + @Test + public void testBadWildcardOptional() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + Pipeline pipeline = DataflowPipeline.create(options); + DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(options); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo?/baz")); + + // Check that translation does fail. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unsupported wildcard usage"); + t.translate(pipeline, Collections.emptyList()); + } + + /** + * The first wildcard must occur after the last directory delimiter. + * This tests "[]" based character classes. + */ + @Test + public void testBadWildcardBrackets() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + Pipeline pipeline = DataflowPipeline.create(options); + DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(options); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo[0-9]/baz")); + + // Check that translation does fail. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unsupported wildcard usage"); + t.translate(pipeline, Collections.emptyList()); + } + + @Test + public void testToSingletonTranslation() throws Exception { + // A "change detector" test that makes sure the translation + // of getting a PCollectionView does not change + // in bad ways during refactor + + DataflowPipelineOptions options = buildPipelineOptions(); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + DataflowPipeline pipeline = DataflowPipeline.create(options); + PCollectionView view = pipeline + .apply(Create.of(1)) + .apply(View.asSingleton()); + Job job = translator.translate(pipeline, Collections.emptyList()); + + List steps = job.getSteps(); + assertEquals(2, steps.size()); + + Step createStep = steps.get(0); + assertEquals("CreateCollection", createStep.getKind()); + + Step collectionToSingletonStep = steps.get(1); + assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); + + } + + @Test + public void testToIterableTranslation() throws Exception { + // A "change detector" test that makes sure the translation + // of getting a PCollectionView, ...> does not change + // in bad ways during refactor + + DataflowPipelineOptions options = buildPipelineOptions(); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + DataflowPipeline pipeline = DataflowPipeline.create(options); + PCollectionView, ?> view = pipeline + .apply(Create.of(1, 2, 3)) + .apply(View.asIterable()); + Job job = translator.translate(pipeline, Collections.emptyList()); + + List steps = job.getSteps(); + assertEquals(2, steps.size()); + + Step createStep = steps.get(0); + assertEquals("CreateCollection", createStep.getKind()); + + Step collectionToSingletonStep = steps.get(1); + assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerTest.java new file mode 100644 index 000000000000..520e03e28b9d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerTest.java @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertTrue; + +import com.google.api.services.dataflow.Dataflow; +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.TestCredential; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; + +/** + * Tests for DataflowPipelineRunner. + */ +@RunWith(JUnit4.class) +public class PipelineRunnerTest { + + @Mock private Dataflow mockDataflow; + @Mock private GcsUtil mockGcsUtil; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testLongName() throws IOException { + // Check we can create a pipeline runner using the full class name. + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setAppName("test"); + options.setProject("test"); + options.setGcsUtil(mockGcsUtil); + options.setRunner(DirectPipelineRunner.class); + options.setGcpCredential(new TestCredential()); + PipelineRunner runner = PipelineRunner.fromOptions(options); + assertTrue(runner instanceof DirectPipelineRunner); + } + + @Test + public void testShortName() throws IOException { + // Check we can create a pipeline runner using the short class name. + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setAppName("test"); + options.setProject("test"); + options.setGcsUtil(mockGcsUtil); + options.setRunner(DirectPipelineRunner.class); + options.setGcpCredential(new TestCredential()); + PipelineRunner runner = PipelineRunner.fromOptions(options); + assertTrue(runner instanceof DirectPipelineRunner); + } + + @Test + public void testAppNameDefault() throws IOException { + ApplicationNameOptions options = PipelineOptionsFactory.as(ApplicationNameOptions.class); + Assert.assertEquals(PipelineRunnerTest.class.getSimpleName(), + options.getAppName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java new file mode 100644 index 000000000000..d0308e87a33f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.First; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PValue; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; + +/** + * Tests for {@link TransformTreeNode} and {@link TransformHierarchy}. + */ +@RunWith(JUnit4.class) +public class TransformTreeTest { + + enum TransformsSeen { + READ, + WRITE, + FIRST + } + + /** + * INVALID TRANSFORM, DO NOT COPY. + * + *

This is an invalid composite transform, which returns unbound outputs. + * This should never happen, and is here to test that it is properly rejected. + */ + private static class InvalidCompositeTransform + extends PTransform> { + + @Override + public PCollectionList apply(PBegin b) { + // Composite transform: apply delegates to other transformations, + // here a Create transform. + PCollection result = b.apply(Create.of("hello", "world")); + + // Issue below: PCollection.createPrimitiveOutput should not be used + // from within a composite transform. + return PCollectionList.of( + Arrays.asList(result, PCollection.createPrimitiveOutputInternal( + new GlobalWindow()))); + } + } + + /** + * A composite transform which returns an output which is unbound. + */ + private static class UnboundOutputCreator + extends PTransform, PDone> { + + @Override + public PDone apply(PCollection input) { + // Apply an operation so that this is a composite transform. + input.apply(Count.perElement()); + + return new PDone(); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + } + + // Builds a pipeline containing a composite operation (First), then + // visits the nodes and verifies that the hierarchy was captured. + @Test + public void testCompositeCapture() throws Exception { + Pipeline p = DirectPipeline.createForTest(); + + p.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/object")) + .apply(First.of(10)) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/object")); + + final EnumSet visited = + EnumSet.noneOf(TransformsSeen.class); + final EnumSet left = + EnumSet.noneOf(TransformsSeen.class); + + p.traverseTopologically(new Pipeline.PipelineVisitor() { + @Override + public void enterCompositeTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + if (transform instanceof First) { + Assert.assertTrue(visited.add(TransformsSeen.FIRST)); + Assert.assertNotNull(node.getEnclosingNode()); + Assert.assertTrue(node.isCompositeNode()); + } + Assert.assertThat(transform, not(instanceOf(TextIO.Read.Bound.class))); + Assert.assertThat(transform, not(instanceOf(TextIO.Write.Bound.class))); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + if (transform instanceof First) { + Assert.assertTrue(left.add(TransformsSeen.FIRST)); + } + } + + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + // First is a composite, should not be visited here. + Assert.assertThat(transform, not(instanceOf(First.class))); + if (transform instanceof TextIO.Read.Bound) { + Assert.assertTrue(visited.add(TransformsSeen.READ)); + } else if (transform instanceof TextIO.Write.Bound) { + Assert.assertTrue(visited.add(TransformsSeen.WRITE)); + } + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + } + }); + + Assert.assertTrue(visited.equals(EnumSet.allOf(TransformsSeen.class))); + Assert.assertTrue(left.equals(EnumSet.of(TransformsSeen.FIRST))); + } + + @Test(expected = IllegalStateException.class) + public void testOutputChecking() throws Exception { + Pipeline p = DirectPipeline.createForTest(); + + p.apply(new InvalidCompositeTransform()); + + p.traverseTopologically(new RecordingPipelineVisitor()); + Assert.fail("traversal should have failed with an IllegalStateException"); + } + + @Test + public void testMultiGraphSetup() throws IOException { + Pipeline p = DirectPipeline.createForTest(); + + PCollection input = p.begin() + .apply(Create.of(1, 2, 3)); + + input.apply(new UnboundOutputCreator()); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSinkTest.java new file mode 100644 index 000000000000..3c7f29b40a8c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSinkTest.java @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.SeekableInput; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.io.DatumReader; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.nio.ByteBuffer; +import java.nio.channels.SeekableByteChannel; +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for AvroByteSink. + */ +@RunWith(JUnit4.class) +public class AvroByteSinkTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + void runTestWriteFile(List elems, Coder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.avro"); + String filename = tmpFile.getPath(); + + // Write the file. + + AvroByteSink avroSink = new AvroByteSink<>(filename, coder); + List actualSizes = new ArrayList<>(); + try (Sink.SinkWriter writer = avroSink.writer()) { + for (T elem : elems) { + actualSizes.add(writer.add(elem)); + } + } + + // Read back the file. + + SeekableByteChannel inChannel = (SeekableByteChannel) + IOChannelUtils.getFactory(filename).open(filename); + + SeekableInput seekableInput = + new AvroSource.SeekableByteChannelInput(inChannel); + + Schema schema = Schema.create(Schema.Type.BYTES); + + DatumReader datumReader = new GenericDatumReader<>(schema); + + DataFileReader fileReader = new DataFileReader<>( + seekableInput, datumReader); + + List actual = new ArrayList<>(); + List expectedSizes = new ArrayList<>(); + ByteBuffer inBuffer = ByteBuffer.allocate(10 * 1024); + while (fileReader.hasNext()) { + inBuffer = fileReader.next(inBuffer); + byte[] encodedElem = new byte[inBuffer.remaining()]; + inBuffer.get(encodedElem); + assert inBuffer.remaining() == 0; + inBuffer.clear(); + T elem = CoderUtils.decodeFromByteArray(coder, encodedElem); + actual.add(elem); + expectedSizes.add((long) encodedElem.length); + } + + fileReader.close(); + + // Compare the expected and the actual elements. + Assert.assertEquals(elems, actual); + Assert.assertEquals(expectedSizes, actualSizes); + } + + @Test + public void testWriteFile() throws Exception { + runTestWriteFile(TestUtils.INTS, BigEndianIntegerCoder.of()); + } + + @Test + public void testWriteEmptyFile() throws Exception { + runTestWriteFile(TestUtils.NO_INTS, BigEndianIntegerCoder.of()); + } + + // TODO: sharded filenames + // TODO: writing to GCS +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSourceTest.java new file mode 100644 index 000000000000..e6bfffdcb68a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroByteSourceTest.java @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.io.DatumWriter; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +import javax.annotation.Nullable; + +/** + * Tests for AvroByteSource. + */ +@RunWith(JUnit4.class) +public class AvroByteSourceTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private void runTestRead(List> elemsList, + Coder coder, + boolean requireExactMatch) + throws Exception { + File tmpFile = tmpFolder.newFile("file.avro"); + String filename = tmpFile.getPath(); + + // Write the data. + OutputStream outStream = Channels.newOutputStream( + IOChannelUtils.create(filename, MimeTypes.BINARY)); + Schema schema = Schema.create(Schema.Type.BYTES); + DatumWriter datumWriter = new GenericDatumWriter<>(schema); + DataFileWriter fileWriter = new DataFileWriter<>(datumWriter); + fileWriter.create(schema, outStream); + boolean first = true; + List syncPoints = new ArrayList<>(); + List expectedSizes = new ArrayList<>(); + for (List elems : elemsList) { + if (first) { + first = false; + } else { + // Ensure a block boundary here. + long syncPoint = fileWriter.sync(); + syncPoints.add(syncPoint); + } + for (T elem : elems) { + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, elem); + fileWriter.append(ByteBuffer.wrap(encodedElem)); + expectedSizes.add(encodedElem.length); + } + } + fileWriter.close(); + + // Test reading the data back. + List> actualElemsList = new ArrayList<>(); + List actualSizes = new ArrayList<>(); + Long startOffset = null; + Long endOffset; + long prevSyncPoint = 0; + for (long syncPoint : syncPoints) { + endOffset = (prevSyncPoint + syncPoint) / 2; + actualElemsList.add(readElems(filename, startOffset, endOffset, coder, actualSizes)); + startOffset = endOffset; + prevSyncPoint = syncPoint; + } + actualElemsList.add(readElems(filename, startOffset, null, coder, actualSizes)); + + // Compare the expected and the actual elements. + if (requireExactMatch) { + // Require the blocks to match exactly. (This works only for + // small block sizes. Large block sizes, bigger than Avro's + // internal sizes, lead to different splits.) + Assert.assertEquals(elemsList, actualElemsList); + } else { + // Just require the overall elements to be the same. (This + // works for any block size.) + List expected = new ArrayList<>(); + for (List elems : elemsList) { + expected.addAll(elems); + } + List actual = new ArrayList<>(); + for (List actualElems : actualElemsList) { + actual.addAll(actualElems); + } + Assert.assertEquals(expected, actual); + } + + Assert.assertEquals(expectedSizes, actualSizes); + } + + private List readElems(String filename, + @Nullable Long startOffset, + @Nullable Long endOffset, + Coder coder, + List actualSizes) + throws Exception { + AvroByteSource avroSource = + new AvroByteSource<>(filename, startOffset, endOffset, coder); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(avroSource, actualSizes); + + List actualElems = new ArrayList<>(); + try (Source.SourceIterator iterator = avroSource.iterator()) { + while (iterator.hasNext()) { + actualElems.add(iterator.next()); + } + } + return actualElems; + } + + @Test + public void testRead() throws Exception { + runTestRead(Collections.singletonList(TestUtils.INTS), + BigEndianIntegerCoder.of(), + true /* require exact match */); + } + + @Test + public void testReadEmpty() throws Exception { + runTestRead(Collections.singletonList(TestUtils.NO_INTS), + BigEndianIntegerCoder.of(), + true /* require exact match */); + } + + private List> generateInputBlocks(int numBlocks, + int blockSizeBytes, + int averageLineSizeBytes) { + Random random = new Random(0); + List> blocks = new ArrayList<>(numBlocks); + for (int blockNum = 0; blockNum < numBlocks; blockNum++) { + int numLines = blockSizeBytes / averageLineSizeBytes; + List lines = new ArrayList<>(numLines); + for (int lineNum = 0; lineNum < numLines; lineNum++) { + int numChars = random.nextInt(averageLineSizeBytes * 2); + StringBuilder sb = new StringBuilder(); + for (int charNum = 0; charNum < numChars; charNum++) { + sb.appendCodePoint(random.nextInt('z' - 'a' + 1) + 'a'); + } + lines.add(sb.toString()); + } + blocks.add(lines); + } + return blocks; + } + + @Test + public void testReadSmallRanges() throws Exception { + runTestRead(generateInputBlocks(3, 50, 5), + StringUtf8Coder.of(), + true /* require exact match */); + } + + @Test + public void testReadBigRanges() throws Exception { + runTestRead(generateInputBlocks(10, 128 * 1024, 100), + StringUtf8Coder.of(), + false /* don't require exact match */); + } + + // TODO: sharded filenames + // TODO: reading from GCS +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkFactoryTest.java new file mode 100644 index 000000000000..79653feabc4c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkFactoryTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for AvroSinkFactory. + */ +@RunWith(JUnit4.class) +public class AvroSinkFactoryTest { + private final String pathToAvroFile = "/path/to/file.avro"; + + Sink runTestCreateAvroSink(String filename, + CloudObject encoding) + throws Exception { + CloudObject spec = CloudObject.forClassName("AvroSink"); + addString(spec, "filename", filename); + + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(encoding); + + Sink sink = SinkFactory.create(PipelineOptionsFactory.create(), cloudSink, + new BatchModeExecutionContext()); + return sink; + } + + @Test + public void testCreateAvroByteSink() throws Exception { + Coder coder = + WindowedValue.getValueOnlyCoder(BigEndianIntegerCoder.of()); + Sink sink = runTestCreateAvroSink( + pathToAvroFile, coder.asCloudObject()); + + Assert.assertThat(sink, new IsInstanceOf(AvroByteSink.class)); + AvroByteSink avroSink = (AvroByteSink) sink; + Assert.assertEquals(pathToAvroFile, avroSink.avroSink.filenamePrefix); + Assert.assertEquals(coder, avroSink.coder); + } + + @Test + public void testCreateAvroSink() throws Exception { + WindowedValue.WindowedValueCoder coder = + WindowedValue.getValueOnlyCoder(AvroCoder.of(Integer.class)); + Sink sink = runTestCreateAvroSink(pathToAvroFile, coder.asCloudObject()); + + Assert.assertThat(sink, new IsInstanceOf(AvroSink.class)); + AvroSink avroSink = (AvroSink) sink; + Assert.assertEquals(pathToAvroFile, avroSink.filenamePrefix); + Assert.assertEquals(coder.getValueCoder(), avroSink.avroCoder); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkTest.java new file mode 100644 index 000000000000..5f22d2774f4b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSinkTest.java @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.SeekableInput; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.io.DatumReader; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.nio.channels.SeekableByteChannel; +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for AvroSink. + */ +@RunWith(JUnit4.class) +public class AvroSinkTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + void runTestWriteFile(List elems, AvroCoder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.avro"); + String filename = tmpFile.getPath(); + + // Write the file. + + AvroSink avroSink = new AvroSink<>(filename, WindowedValue.getValueOnlyCoder(coder)); + List actualSizes = new ArrayList<>(); + try (Sink.SinkWriter> writer = avroSink.writer()) { + for (T elem : elems) { + actualSizes.add(writer.add(WindowedValue.valueInGlobalWindow(elem))); + } + } + + // Read back the file. + + SeekableByteChannel inChannel = (SeekableByteChannel) + IOChannelUtils.getFactory(filename).open(filename); + + SeekableInput seekableInput = + new AvroSource.SeekableByteChannelInput(inChannel); + + DatumReader datumReader = new GenericDatumReader<>(coder.getSchema()); + + DataFileReader fileReader = new DataFileReader<>( + seekableInput, datumReader); + + List actual = new ArrayList<>(); + List expectedSizes = new ArrayList<>(); + while (fileReader.hasNext()) { + T next = fileReader.next(); + actual.add(next); + expectedSizes.add((long) CoderUtils.encodeToByteArray(coder, next).length); + } + + fileReader.close(); + + // Compare the expected and the actual elements. + Assert.assertEquals(elems, actual); + Assert.assertEquals(expectedSizes, actualSizes); + } + + @Test + public void testWriteFile() throws Exception { + runTestWriteFile(TestUtils.INTS, AvroCoder.of(Integer.class)); + } + + @Test + public void testWriteEmptyFile() throws Exception { + runTestWriteFile(TestUtils.NO_INTS, AvroCoder.of(Integer.class)); + } + + // TODO: sharded filenames + // TODO: writing to GCS +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceFactoryTest.java new file mode 100644 index 000000000000..3c81950fd29d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceFactoryTest.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import javax.annotation.Nullable; + +/** + * Tests for AvroSourceFactory. + */ +@RunWith(JUnit4.class) +public class AvroSourceFactoryTest { + private final String pathToAvroFile = "/path/to/file.avro"; + + Source runTestCreateAvroSource(String filename, + @Nullable Long start, + @Nullable Long end, + CloudObject encoding) + throws Exception { + CloudObject spec = CloudObject.forClassName("AvroSource"); + addString(spec, "filename", filename); + if (start != null) { + addLong(spec, "start_offset", start); + } + if (end != null) { + addLong(spec, "end_offset", end); + } + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(encoding); + + Source source = SourceFactory.create(PipelineOptionsFactory.create(), + cloudSource, + new BatchModeExecutionContext()); + return source; + } + + @Test + public void testCreatePlainAvroByteSource() throws Exception { + Coder coder = + WindowedValue.getValueOnlyCoder(BigEndianIntegerCoder.of()); + Source source = runTestCreateAvroSource( + pathToAvroFile, null, null, coder.asCloudObject()); + + Assert.assertThat(source, new IsInstanceOf(AvroByteSource.class)); + AvroByteSource avroSource = (AvroByteSource) source; + Assert.assertEquals(pathToAvroFile, avroSource.avroSource.filename); + Assert.assertEquals(null, avroSource.avroSource.startPosition); + Assert.assertEquals(null, avroSource.avroSource.endPosition); + Assert.assertEquals(coder, avroSource.coder); + } + + @Test + public void testCreateRichAvroByteSource() throws Exception { + Coder coder = + WindowedValue.getValueOnlyCoder(BigEndianIntegerCoder.of()); + Source source = runTestCreateAvroSource( + pathToAvroFile, 200L, 500L, coder.asCloudObject()); + + Assert.assertThat(source, new IsInstanceOf(AvroByteSource.class)); + AvroByteSource avroSource = (AvroByteSource) source; + Assert.assertEquals(pathToAvroFile, avroSource.avroSource.filename); + Assert.assertEquals(200L, (long) avroSource.avroSource.startPosition); + Assert.assertEquals(500L, (long) avroSource.avroSource.endPosition); + Assert.assertEquals(coder, avroSource.coder); + } + + @Test + public void testCreateRichAvroSource() throws Exception { + WindowedValue.WindowedValueCoder coder = + WindowedValue.getValueOnlyCoder(AvroCoder.of(Integer.class)); + Source source = runTestCreateAvroSource( + pathToAvroFile, 200L, 500L, coder.asCloudObject()); + + Assert.assertThat(source, new IsInstanceOf(AvroSource.class)); + AvroSource avroSource = (AvroSource) source; + Assert.assertEquals(pathToAvroFile, avroSource.filename); + Assert.assertEquals(200L, (long) avroSource.startPosition); + Assert.assertEquals(500L, (long) avroSource.endPosition); + Assert.assertEquals(coder.getValueCoder(), avroSource.avroCoder); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceTest.java new file mode 100644 index 000000000000..4855ef92e4d9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/AvroSourceTest.java @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.io.DatumWriter; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +import javax.annotation.Nullable; + +/** + * Tests for AvroSource. + */ +@RunWith(JUnit4.class) +public class AvroSourceTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private void runTestRead(List> elemsList, + AvroCoder coder, + boolean requireExactMatch) + throws Exception { + File tmpFile = tmpFolder.newFile("file.avro"); + String filename = tmpFile.getPath(); + + // Write the data. + OutputStream outStream = Channels.newOutputStream( + IOChannelUtils.create(filename, MimeTypes.BINARY)); + DatumWriter datumWriter = coder.createDatumWriter(); + DataFileWriter fileWriter = new DataFileWriter<>(datumWriter); + fileWriter.create(coder.getSchema(), outStream); + boolean first = true; + List syncPoints = new ArrayList<>(); + List expectedSizes = new ArrayList<>(); + for (List elems : elemsList) { + if (first) { + first = false; + } else { + // Ensure a block boundary here. + long syncPoint = fileWriter.sync(); + syncPoints.add(syncPoint); + } + for (T elem : elems) { + fileWriter.append(elem); + expectedSizes.add(CoderUtils.encodeToByteArray(coder, elem).length); + } + } + fileWriter.close(); + + // Test reading the data back. + List> actualElemsList = new ArrayList<>(); + List actualSizes = new ArrayList<>(); + Long startOffset = null; + Long endOffset; + long prevSyncPoint = 0; + for (long syncPoint : syncPoints) { + endOffset = (prevSyncPoint + syncPoint) / 2; + actualElemsList.add(readElems(filename, startOffset, endOffset, coder, actualSizes)); + startOffset = endOffset; + prevSyncPoint = syncPoint; + } + actualElemsList.add(readElems(filename, startOffset, null, coder, actualSizes)); + + // Compare the expected and the actual elements. + if (requireExactMatch) { + // Require the blocks to match exactly. (This works only for + // small block sizes. Large block sizes, bigger than Avro's + // internal sizes, lead to different splits.) + Assert.assertEquals(elemsList, actualElemsList); + } else { + // Just require the overall elements to be the same. (This + // works for any block size.) + List expected = new ArrayList<>(); + for (List elems : elemsList) { + expected.addAll(elems); + } + List actual = new ArrayList<>(); + for (List actualElems : actualElemsList) { + actual.addAll(actualElems); + } + Assert.assertEquals(expected, actual); + } + + Assert.assertEquals(expectedSizes, actualSizes); + } + + private List readElems(String filename, + @Nullable Long startOffset, + @Nullable Long endOffset, + Coder coder, + List actualSizes) + throws Exception { + AvroSource avroSource = + new AvroSource<>(filename, startOffset, endOffset, WindowedValue.getValueOnlyCoder(coder)); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(avroSource, actualSizes); + + List actualElems = new ArrayList<>(); + try (Source.SourceIterator> iterator = avroSource.iterator()) { + while (iterator.hasNext()) { + actualElems.add(iterator.next().getValue()); + } + } + return actualElems; + } + + @Test + public void testRead() throws Exception { + runTestRead(Collections.singletonList(TestUtils.INTS), + AvroCoder.of(Integer.class), + true /* require exact match */); + } + + @Test + public void testReadEmpty() throws Exception { + runTestRead(Collections.singletonList(TestUtils.NO_INTS), + AvroCoder.of(Integer.class), + true /* require exact match */); + } + + private List> generateInputBlocks(int numBlocks, + int blockSizeBytes, + int averageLineSizeBytes) { + Random random = new Random(0); + List> blocks = new ArrayList<>(numBlocks); + for (int blockNum = 0; blockNum < numBlocks; blockNum++) { + int numLines = blockSizeBytes / averageLineSizeBytes; + List lines = new ArrayList<>(numLines); + for (int lineNum = 0; lineNum < numLines; lineNum++) { + int numChars = random.nextInt(averageLineSizeBytes * 2); + StringBuilder sb = new StringBuilder(); + for (int charNum = 0; charNum < numChars; charNum++) { + sb.appendCodePoint(random.nextInt('z' - 'a' + 1) + 'a'); + } + lines.add(sb.toString()); + } + blocks.add(lines); + } + return blocks; + } + + @Test + public void testReadSmallRanges() throws Exception { + runTestRead(generateInputBlocks(3, 50, 5), + AvroCoder.of(String.class), + true /* require exact match */); + } + + @Test + public void testReadBigRanges() throws Exception { + runTestRead(generateInputBlocks(10, 128 * 1024, 100), + AvroCoder.of(String.class), + false /* don't require exact match */); + } + + // TODO: sharded filenames + // TODO: reading from GCS +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceFactoryTest.java new file mode 100644 index 000000000000..0eb95c70205c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceFactoryTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for BigQuerySourceFactory. + */ +@RunWith(JUnit4.class) +public class BigQuerySourceFactoryTest { + void runTestCreateBigQuerySource(String project, + String dataset, + String table, + CloudObject encoding) + throws Exception { + CloudObject spec = CloudObject.forClassName("BigQuerySource"); + addString(spec, "project", project); + addString(spec, "dataset", dataset); + addString(spec, "table", table); + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(encoding); + + Source source = SourceFactory.create(PipelineOptionsFactory.create(), + cloudSource, + new BatchModeExecutionContext()); + Assert.assertThat(source, new IsInstanceOf(BigQuerySource.class)); + BigQuerySource bigQuerySource = (BigQuerySource) source; + Assert.assertEquals(project, bigQuerySource.tableRef.getProjectId()); + Assert.assertEquals(dataset, bigQuerySource.tableRef.getDatasetId()); + Assert.assertEquals(table, bigQuerySource.tableRef.getTableId()); + } + + @Test + public void testCreateBigQuerySource() throws Exception { + runTestCreateBigQuerySource( + "someproject", "somedataset", "sometable", + makeCloudEncoding("TableRowJsonCoder")); + } + + @Test + public void testCreateBigQuerySourceCoderIgnored() throws Exception { + // BigQuery sources do not need a coder because the TableRow objects are read directly from + // the table using the BigQuery API. + runTestCreateBigQuerySource( + "someproject", "somedataset", "sometable", + makeCloudEncoding("BigEndianIntegerCoder")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceTest.java new file mode 100644 index 000000000000..2ed4635e8c10 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/BigQuerySourceTest.java @@ -0,0 +1,183 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableCell; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for BigQuerySource. + * + *

The tests just make sure a basic scenario of reading works because the class itself is a + * thin wrapper over {@code BigQueryTableRowIterator}. The tests for the wrapped class have + * comprehensive coverage. + */ +@RunWith(JUnit4.class) +public class BigQuerySourceTest { + + @Mock private Bigquery mockClient; + @Mock private Bigquery.Tables mockTables; + @Mock private Bigquery.Tables.Get mockTablesGet; + @Mock private Bigquery.Tabledata mockTabledata; + @Mock private Bigquery.Tabledata.List mockTabledataList; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockClient); + verifyNoMoreInteractions(mockTables); + verifyNoMoreInteractions(mockTablesGet); + verifyNoMoreInteractions(mockTabledata); + verifyNoMoreInteractions(mockTabledataList); + } + + private void onTableGet(Table table) throws IOException { + when(mockClient.tables()) + .thenReturn(mockTables); + when(mockTables.get(anyString(), anyString(), anyString())) + .thenReturn(mockTablesGet); + when(mockTablesGet.execute()) + .thenReturn(table); + } + + private void verifyTableGet() throws IOException { + verify(mockClient).tables(); + verify(mockTables).get("project", "dataset", "table"); + verify(mockTablesGet).execute(); + } + + private void onTableList(TableDataList result) throws IOException { + when(mockClient.tabledata()) + .thenReturn(mockTabledata); + when(mockTabledata.list(anyString(), anyString(), anyString())) + .thenReturn(mockTabledataList); + when(mockTabledataList.execute()) + .thenReturn(result); + } + + private void verifyTabledataList() throws IOException { + verify(mockClient, atLeastOnce()).tabledata(); + verify(mockTabledata, atLeastOnce()).list("project", "dataset", "table"); + verify(mockTabledataList, atLeastOnce()).execute(); + // Max results may be set when testing for an empty table. + verify(mockTabledataList, atLeast(0)).setMaxResults(anyLong()); + } + + private Table basicTableSchema() { + return new Table() + .setSchema(new TableSchema() + .setFields(Arrays.asList( + new TableFieldSchema() + .setName("name") + .setType("STRING"), + new TableFieldSchema() + .setName("integer") + .setType("INTEGER"), + new TableFieldSchema() + .setName("float") + .setType("FLOAT"), + new TableFieldSchema() + .setName("bool") + .setType("BOOLEAN") + ))); + } + + private TableRow rawRow(Object...args) { + List cells = new LinkedList<>(); + for (Object a : args) { + cells.add(new TableCell().setV(a)); + } + return new TableRow().setF(cells); + } + + private TableDataList rawDataList(TableRow...rows) { + return new TableDataList() + .setRows(Arrays.asList(rows)); + } + + @Test + public void testRead() throws IOException { + onTableGet(basicTableSchema()); + + // BQ API data is always encoded as a string + TableDataList dataList = rawDataList( + rawRow("Arthur", "42", "3.14159", "false"), + rawRow("Allison", "79", "2.71828", "true") + ); + onTableList(dataList); + + BigQuerySource source = new BigQuerySource( + mockClient, + new TableReference() + .setProjectId("project") + .setDatasetId("dataset") + .setTableId("table")); + + BigQuerySource.SourceIterator iterator = source.iterator(); + Assert.assertTrue(iterator.hasNext()); + TableRow row = iterator.next(); + + Assert.assertEquals("Arthur", row.get("name")); + Assert.assertEquals("42", row.get("integer")); + Assert.assertEquals(3.14159, row.get("float")); + Assert.assertEquals(false, row.get("bool")); + + row = iterator.next(); + + Assert.assertEquals("Allison", row.get("name")); + Assert.assertEquals("79", row.get("integer")); + Assert.assertEquals(2.71828, row.get("float")); + Assert.assertEquals(true, row.get("bool")); + + Assert.assertFalse(iterator.hasNext()); + + verifyTableGet(); + verifyTabledataList(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/CombineValuesFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/CombineValuesFnTest.java new file mode 100644 index 000000000000..b616f6d75f37 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/CombineValuesFnTest.java @@ -0,0 +1,337 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoFn; +import com.google.cloud.dataflow.sdk.util.common.worker.Receiver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.reflect.TypeToken; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for CombineValuesFn. + */ +@RunWith(JUnit4.class) +public class CombineValuesFnTest { + /** Example AccumulatingCombineFn. */ + public static class MeanInts extends + Combine.AccumulatingCombineFn { + + class CountSum extends + Combine.AccumulatingCombineFn.Accumulator { + + long count; + double sum; + + @Override + public void addInput(Integer element) { + count++; + sum += element.doubleValue(); + } + + @Override + public void mergeAccumulator(CountSum accumulator) { + count += accumulator.count; + sum += accumulator.sum; + } + + @Override + public String extractOutput() { + return String.format("%.1f", count == 0 ? 0.0 : sum / count); + } + + public CountSum(long count, double sum) { + this.count = count; + this.sum = sum; + } + + @Override + public int hashCode() { + return KV.of(count, sum).hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof CountSum)) { + return false; + } + if (obj == this) { + return true; + } + + CountSum other = (CountSum) obj; + return (this.count == other.count) + && (Math.abs(this.sum - other.sum) < 0.1); + } + } + + @Override + public CountSum createAccumulator() { + return new CountSum(0, 0.0); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return new CountSumCoder(); + } + } + + /** + * An example "cheap" accumulator coder. + */ + public static class CountSumCoder implements Coder { + public CountSumCoder() { } + + @Override + public void encode( + MeanInts.CountSum value, OutputStream outStream, Context context) + throws CoderException, IOException { + DataOutputStream dataStream = new DataOutputStream(outStream); + dataStream.writeLong(value.count); + dataStream.writeDouble(value.sum); + } + + @Override + public MeanInts.CountSum decode(InputStream inStream, Context context) + throws CoderException, IOException { + DataInputStream dataStream = new DataInputStream(inStream); + long count = dataStream.readLong(); + double sum = dataStream.readDouble(); + return (new MeanInts ()).new CountSum(count, sum); + } + + @Override + public boolean isDeterministic() { return true; } + + public CloudObject asCloudObject() { + return makeCloudEncoding(this.getClass().getName()); + } + + @Override + public List> getCoderArguments() { return null; } + + public List getInstanceComponents(MeanInts.CountSum exampleValue) { + return null; + } + + @Override + public boolean isRegisterByteSizeObserverCheap( + MeanInts.CountSum value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + MeanInts.CountSum value, ElementByteSizeObserver observer, Context ctx) + throws Exception { + observer.update((long) 16); + } + } + + static class TestReceiver implements Receiver { + List receivedElems = new ArrayList<>(); + + @Override + public void process(Object outputElem) { + receivedElems.add(outputElem); + } + } + + private static ParDoFn createCombineValuesFn( + String phase, Combine.KeyedCombineFn combineFn) throws Exception { + // This partially mirrors the work that + // com.google.cloud.dataflow.sdk.transforms.Combine.translateHelper + // does, at least for the KeyedCombineFn. The phase is generated + // by the back-end. + CloudObject spec = CloudObject.forClassName("CombineValuesFn"); + addString(spec, PropertyNames.SERIALIZED_FN, + byteArrayToJsonString(serializeToByteArray(combineFn))); + addString(spec, PropertyNames.PHASE, phase); + + return CombineValuesFn.create( + PipelineOptionsFactory.create(), + spec, + "name", + null, // no side inputs + null, // no side outputs + 1, // single main output + new BatchModeExecutionContext(), + (new CounterSet()).getAddCounterMutator(), + null); + } + + @Test + public void testCombineValuesFnAll() throws Exception { + TestReceiver receiver = new TestReceiver(); + + Combine.KeyedCombineFn combiner = + (new MeanInts()).asKeyedFn(); + + ParDoFn combineParDoFn = createCombineValuesFn( + CombineValuesFn.CombinePhase.ALL, combiner); + + combineParDoFn.startBundle(receiver); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("a", Arrays.asList(5, 6, 7)))); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("b", Arrays.asList(1, 3, 7)))); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("c", Arrays.asList(3, 6, 8, 9)))); + combineParDoFn.finishBundle(); + + Object[] expectedReceivedElems = { + WindowedValue.valueInGlobalWindow(KV.of("a", "6.0")), + WindowedValue.valueInGlobalWindow(KV.of("b", "3.7")), + WindowedValue.valueInGlobalWindow(KV.of("c", "6.5")), + }; + assertArrayEquals(expectedReceivedElems, receiver.receivedElems.toArray()); + } + + @Test + public void testCombineValuesFnAdd() throws Exception { + TestReceiver receiver = new TestReceiver(); + MeanInts mean = new MeanInts(); + + Combine.KeyedCombineFn combiner = mean.asKeyedFn(); + + ParDoFn combineParDoFn = createCombineValuesFn( + CombineValuesFn.CombinePhase.ADD, combiner); + + combineParDoFn.startBundle(receiver); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("a", Arrays.asList(5, 6, 7)))); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("b", Arrays.asList(1, 3, 7)))); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("c", Arrays.asList(3, 6, 8, 9)))); + combineParDoFn.finishBundle(); + + Object[] expectedReceivedElems = { + WindowedValue.valueInGlobalWindow(KV.of("a", mean.new CountSum(3, 18))), + WindowedValue.valueInGlobalWindow(KV.of("b", mean.new CountSum(3, 11))), + WindowedValue.valueInGlobalWindow(KV.of("c", mean.new CountSum(4, 26))) + }; + assertArrayEquals(expectedReceivedElems, receiver.receivedElems.toArray()); + } + + @Test + public void testCombineValuesFnMerge() throws Exception { + TestReceiver receiver = new TestReceiver(); + MeanInts mean = new MeanInts(); + + Combine.KeyedCombineFn combiner = mean.asKeyedFn(); + + ParDoFn combineParDoFn = createCombineValuesFn( + CombineValuesFn.CombinePhase.MERGE, combiner); + + combineParDoFn.startBundle(receiver); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("a", + Arrays.asList( + mean.new CountSum(3, 6), + mean.new CountSum(2, 9), + mean.new CountSum(1, 12))))); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("b", + Arrays.asList( + mean.new CountSum(2, 20), + mean.new CountSum(1, 1))))); + combineParDoFn.finishBundle(); + + Object[] expectedReceivedElems = { + WindowedValue.valueInGlobalWindow(KV.of("a", mean.new CountSum(6, 27))), + WindowedValue.valueInGlobalWindow(KV.of("b", mean.new CountSum(3, 21))), + }; + assertArrayEquals(expectedReceivedElems, receiver.receivedElems.toArray()); + } + + @Test + public void testCombineValuesFnExtract() throws Exception { + TestReceiver receiver = new TestReceiver(); + MeanInts mean = new MeanInts(); + + Combine.KeyedCombineFn combiner = mean.asKeyedFn(); + + ParDoFn combineParDoFn = createCombineValuesFn( + CombineValuesFn.CombinePhase.EXTRACT, combiner); + + combineParDoFn.startBundle(receiver); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("a", mean.new CountSum(6, 27)))); + combineParDoFn.processElement(WindowedValue.valueInGlobalWindow( + KV.of("b", mean.new CountSum(3, 21)))); + combineParDoFn.finishBundle(); + + assertArrayEquals( + new Object[]{ WindowedValue.valueInGlobalWindow(KV.of("a", "4.5")), + WindowedValue.valueInGlobalWindow(KV.of("b", "7.0")) }, + receiver.receivedElems.toArray()); + } + + @Test + public void testCombineValuesFnCoders() throws Exception { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + + MeanInts meanInts = new MeanInts(); + MeanInts.CountSum countSum = meanInts.new CountSum(6, 27); + + Coder coder = meanInts.getAccumulatorCoder( + registry, registry.getDefaultCoder(TypeToken.of(Integer.class))); + + assertEquals( + countSum, + CoderUtils.decodeFromByteArray(coder, + CoderUtils.encodeToByteArray(coder, countSum))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/CopyableSeekableByteChannelTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/CopyableSeekableByteChannelTest.java new file mode 100644 index 000000000000..e27fa1832870 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/CopyableSeekableByteChannelTest.java @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Preconditions.checkArgument; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SeekableByteChannel; + +/** Unit tests for {@link CopyableSeekableByteChannel}. */ +@RunWith(JUnit4.class) +public final class CopyableSeekableByteChannelTest { + @Test + public void copiedChannelShouldMaintainIndependentPosition() + throws IOException { + ByteBuffer dst = ByteBuffer.allocate(6); + SeekableByteChannel base = + new FakeSeekableByteChannel("Hello, world! :-)".getBytes()); + base.position(1); + + CopyableSeekableByteChannel chan = new CopyableSeekableByteChannel(base); + assertThat(chan.position(), equalTo((long) 1)); + + CopyableSeekableByteChannel copy = chan.copy(); + assertThat(copy.position(), equalTo((long) 1)); + + assertThat(chan.read(dst), equalTo(6)); + assertThat(chan.position(), equalTo((long) 7)); + assertThat(new String(dst.array()), equalTo("ello, ")); + dst.rewind(); + + assertThat(copy.position(), equalTo((long) 1)); + copy.position(3); + assertThat(copy.read(dst), equalTo(6)); + assertThat(copy.position(), equalTo((long) 9)); + assertThat(new String(dst.array()), equalTo("lo, wo")); + dst.rewind(); + + assertThat(chan.read(dst), equalTo(6)); + assertThat(chan.position(), equalTo((long) 13)); + assertThat(new String(dst.array()), equalTo("world!")); + dst.rewind(); + + assertThat(chan.read(dst), equalTo(4)); + assertThat(chan.position(), equalTo((long) 17)); + assertThat(new String(dst.array()), equalTo(" :-)d!")); + dst.rewind(); + + assertThat(copy.position(), equalTo((long) 9)); + assertThat(copy.read(dst), equalTo(6)); + assertThat(new String(dst.array()), equalTo("rld! :")); + } + + private static final class FakeSeekableByteChannel + implements SeekableByteChannel { + private boolean closed = false; + private ByteBuffer data; + + public FakeSeekableByteChannel(byte[] data) { + this.data = ByteBuffer.wrap(data); + } + + @Override + public long position() throws IOException { + checkClosed(); + return data.position(); + } + + @Override + public SeekableByteChannel position(long newPosition) throws IOException { + checkArgument(newPosition >= 0); + checkClosed(); + data.position((int) newPosition); + return this; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + checkClosed(); + if (!data.hasRemaining()) { + return -1; + } + int count = Math.min(data.remaining(), dst.remaining()); + ByteBuffer src = data.slice(); + src.limit(count); + dst.put(src); + data.position(data.position() + count); + return count; + } + + @Override + public long size() throws IOException { + checkClosed(); + return data.limit(); + } + + @Override + public SeekableByteChannel truncate(long size) throws IOException { + checkClosed(); + data.limit((int) size); + return this; + } + + @Override + public int write(ByteBuffer src) throws IOException { + checkClosed(); + int count = Math.min(data.remaining(), src.remaining()); + ByteBuffer copySrc = src.slice(); + copySrc.limit(count); + data.put(copySrc); + return count; + } + + @Override + public boolean isOpen() { + return !closed; + } + + @Override + public void close() { + closed = true; + } + + private void checkClosed() throws ClosedChannelException { + if (closed) { + throw new ClosedChannelException(); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkProgressUpdaterTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkProgressUpdaterTest.java new file mode 100644 index 000000000000..2167a504183f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkProgressUpdaterTest.java @@ -0,0 +1,438 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudPositionToSourcePosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static com.google.cloud.dataflow.sdk.util.CloudCounterUtils.extractCounter; +import static com.google.cloud.dataflow.sdk.util.CloudMetricUtils.extractCloudMetric; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudTime; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SET; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.api.services.dataflow.model.Position; +import com.google.api.services.dataflow.model.WorkItem; +import com.google.api.services.dataflow.model.WorkItemServiceState; +import com.google.api.services.dataflow.model.WorkItemStatus; +import com.google.cloud.dataflow.sdk.options.DataflowWorkerHarnessOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.CounterTestUtils; +import com.google.cloud.dataflow.sdk.util.common.Metric; +import com.google.cloud.dataflow.sdk.util.common.Metric.DoubleMetric; +import com.google.cloud.dataflow.sdk.util.common.worker.MapTaskExecutor; +import com.google.cloud.dataflow.sdk.util.common.worker.Operation; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; + +import org.hamcrest.Description; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatcher; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import javax.annotation.Nullable; + +/** Unit tests for {@link DataflowWorkProgressUpdater}. */ +@RunWith(JUnit4.class) +public class DataflowWorkProgressUpdaterTest { + static class TestMapTaskExecutor extends MapTaskExecutor { + ApproximateProgress progress = null; + + public TestMapTaskExecutor(CounterSet counters) { + super(new ArrayList(), + counters, + new StateSampler("test", counters.getAddCounterMutator())); + } + + @Override + public Source.Progress getWorkerProgress() { + return cloudProgressToSourceProgress(progress); + } + + @Override + public Source.Position proposeStopPosition( + Source.Progress suggestedStopPoint) { + @Nullable ApproximateProgress progress = sourceProgressToCloudProgress(suggestedStopPoint); + if (progress == null) { + return null; + } + return cloudPositionToSourcePosition(progress.getPosition()); + } + + public void setWorkerProgress(ApproximateProgress progress) { + this.progress = progress; + } + } + + static { + // To shorten wait times during testing. + System.setProperty("minimum_worker_update_interval_millis", "100"); + System.setProperty("worker_lease_renewal_latency_margin", "100"); + } + + private static final String PROJECT_ID = "TEST_PROJECT_ID"; + private static final String JOB_ID = "TEST_JOB_ID"; + private static final String WORKER_ID = "TEST_WORKER_ID"; + private static final Long WORK_ID = 1234567890L; + private static final String COUNTER_NAME = "test-counter-"; + private static final AggregationKind[] COUNTER_KINDS = {SUM, MAX, SET}; + private static final Long COUNTER_VALUE1 = 12345L; + private static final Double COUNTER_VALUE2 = Math.PI; + private static final String COUNTER_VALUE3 = "value"; + + @Rule public final ExpectedException thrown = ExpectedException.none(); + @Mock private DataflowWorker.WorkUnitClient workUnitClient; + private CounterSet counters; + private List> metrics; + private TestMapTaskExecutor worker; + private WorkItem workItem; + private DataflowWorkerHarnessOptions options; + private DataflowWorkProgressUpdater progressUpdater; + private long nowMillis; + + @Before + public void initMocksAndWorkflowServiceAndWorkerAndWork() throws IOException { + MockitoAnnotations.initMocks(this); + + options = PipelineOptionsFactory.createFromSystemProperties(); + options.setProject(PROJECT_ID); + options.setJobId(JOB_ID); + options.setWorkerId(WORKER_ID); + + metrics = new ArrayList<>(); + counters = new CounterSet(); + worker = new TestMapTaskExecutor(counters) { + @Override + public Collection> getOutputMetrics() { + return metrics; + } + }; + nowMillis = System.currentTimeMillis(); + + workItem = new WorkItem(); + workItem.setProjectId(PROJECT_ID); + workItem.setJobId(JOB_ID); + workItem.setId(WORK_ID); + workItem.setLeaseExpireTime(toCloudTime(new Instant(nowMillis + 1000))); + workItem.setReportStatusInterval(toCloudDuration(Duration.millis(500))); + + progressUpdater = new DataflowWorkProgressUpdater( + workItem, worker, workUnitClient, options); + } + + // TODO: Remove sleeps from this test by using a mock sleeper. This + // requires a redesign of the WorkProgressUpdater to use a Sleeper and + // not use a ScheduledThreadExecutor which relies on real time passing. + @Test(timeout = 2000) + public void workProgressUpdaterUpdates() throws Exception { + when(workUnitClient.reportWorkItemStatus(any(WorkItemStatus.class))).thenReturn( + generateServiceState(nowMillis + 2000, 1000, null)); + setUpCounters(2); + setUpMetrics(3); + setUpProgress(makeRecordIndexProgress(1L)); + progressUpdater.startReportingProgress(); + // The initial update should be sent after leaseRemainingTime / 2. + verify(workUnitClient, timeout(600)).reportWorkItemStatus(argThat( + new ExpectedDataflowProgress() + .withCounters(2) + .withMetrics(3) + .withProgress(makeRecordIndexProgress(1L)))); + progressUpdater.stopReportingProgress(); + } + + // Verifies that ReportWorkItemStatusRequest contains correct progress report + // and actual stop position report. + @Test(timeout = 5000) + public void workProgressUpdaterAdaptsProgressInterval() throws Exception { + // Mock that the next reportProgress call will return a response that asks + // us to truncate the task at index 3, and the next two will not ask us to + // truncate at all. + when(workUnitClient.reportWorkItemStatus(any(WorkItemStatus.class))) + .thenReturn(generateServiceState(nowMillis + 2000, 1000, + makeRecordIndexPosition(3L))) + .thenReturn(generateServiceState(nowMillis + 3000, 2000, null)) + .thenReturn(generateServiceState(nowMillis + 4000, 3000, null)); + + setUpCounters(3); + setUpMetrics(2); + setUpProgress(makeRecordIndexProgress(1L)); + progressUpdater.startReportingProgress(); + // The initial update should be sent after + // leaseRemainingTime (1000) / 2 = 500. + verify(workUnitClient, timeout(600)).reportWorkItemStatus(argThat( + new ExpectedDataflowProgress() + .withCounters(3) + .withMetrics(2) + .withProgress(makeRecordIndexProgress(1L)))); + + setUpCounters(5); + setUpMetrics(6); + setUpProgress(makeRecordIndexProgress(2L)); + // The second update should be sent after one second (2000 / 2). + verify(workUnitClient, timeout(1100)).reportWorkItemStatus(argThat( + new ExpectedDataflowProgress() + .withCounters(5) + .withMetrics(6) + .withProgress(makeRecordIndexProgress(2L)) + .withStopPosition(makeRecordIndexPosition(3L)))); + + // After the request is sent, reset stop position cache to null. + assertNull(progressUpdater.getStopPosition()); + + setUpProgress(makeRecordIndexProgress(3L)); + + // The third update should be sent after one and half seconds (3000 / 2). + verify(workUnitClient, timeout(1600)).reportWorkItemStatus(argThat( + new ExpectedDataflowProgress() + .withProgress(makeRecordIndexProgress(3L)))); + + progressUpdater.stopReportingProgress(); + } + + // Verifies that a last update is sent when there is an unacknowledged split request. + @Test(timeout = 3000) + public void workProgressUpdaterLastUpdate() throws Exception { + when(workUnitClient.reportWorkItemStatus(any(WorkItemStatus.class))) + .thenReturn(generateServiceState(nowMillis + 2000, 1000, + makeRecordIndexPosition(2L))) + .thenReturn(generateServiceState(nowMillis + 3000, 2000, null)); + + setUpProgress(makeRecordIndexProgress(1L)); + progressUpdater.startReportingProgress(); + // The initial update should be sent after leaseRemainingTime / 2 = 500 msec. + Thread.sleep(600); + verify(workUnitClient, timeout(200)).reportWorkItemStatus(argThat( + new ExpectedDataflowProgress() + .withProgress(makeRecordIndexProgress(1L)))); + + // The first update should include the new actual stop position. + // Verify that the progressUpdater has recorded it. + assertEquals(makeRecordIndexPosition(2L), + sourcePositionToCloudPosition(progressUpdater.getStopPosition())); + + setUpProgress(makeRecordIndexProgress(2L)); + // The second update should be sent after one second (2000 / 2). + Thread.sleep(200); // not enough time for an update so the latest stop position is not + // acknowledged. + // Check that the progressUpdater still has a pending stop position to send + assertEquals(makeRecordIndexPosition(2L), + sourcePositionToCloudPosition(progressUpdater.getStopPosition())); + + progressUpdater.stopReportingProgress(); // should send the last update + // check that the progressUpdater is done with reporting its latest stop position + assertNull(progressUpdater.getStopPosition()); + + // Verify that the last update contained the latest stop position + verify(workUnitClient, timeout(1000)).reportWorkItemStatus(argThat( + new ExpectedDataflowProgress() + .withStopPosition(makeRecordIndexPosition(2L)))); + } + + private void setUpCounters(int n) { + counters.clear(); + for (int i = 0; i < n; i++) { + counters.add(makeCounter(i)); + } + } + + private static Counter makeCounter(int i) { + if (i % 3 == 0) { + return Counter.longs(COUNTER_NAME + i, COUNTER_KINDS[0]) + .addValue(COUNTER_VALUE1 + i).addValue(COUNTER_VALUE1 + i * 2); + } else if (i % 3 == 1) { + return Counter.doubles(COUNTER_NAME + i, COUNTER_KINDS[1]) + .addValue(COUNTER_VALUE2 + i).addValue(COUNTER_VALUE2 + i * 3); + } else { + return Counter.strings(COUNTER_NAME + i, COUNTER_KINDS[2]) + .addValue(COUNTER_VALUE3 + i).addValue(COUNTER_NAME + i * 5); + } + } + + private static Metric makeMetric(int i) { + return new DoubleMetric(String.valueOf(i), (double) i); + } + + private void setUpMetrics(int n) { + metrics = new ArrayList<>(); + for (int i = 0; i < n; i++) { + metrics.add(makeMetric(i)); + } + } + + private void setUpProgress(ApproximateProgress progress) { + worker.setWorkerProgress(progress); + } + + private com.google.api.services.dataflow.model.Position makeRecordIndexPosition(Long index) { + com.google.api.services.dataflow.model.Position position = + new com.google.api.services.dataflow.model.Position(); + position.setRecordIndex(index); + return position; + } + + private ApproximateProgress makeRecordIndexProgress(Long index) { + return new ApproximateProgress().setPosition(makeRecordIndexPosition(index)); + } + + private WorkItemServiceState generateServiceState( + long leaseExpirationTimestamp, int progressReportIntervalMs, + Position suggestedStopPosition) + throws IOException { + WorkItemServiceState responseState = new WorkItemServiceState(); + responseState.setFactory(Transport.getJsonFactory()); + responseState.setLeaseExpireTime(toCloudTime(new Instant(leaseExpirationTimestamp))); + responseState.setReportStatusInterval( + toCloudDuration(Duration.millis(progressReportIntervalMs))); + + if (suggestedStopPosition != null) { + responseState.setSuggestedStopPosition(suggestedStopPosition); + } + + return responseState; + } + + private static final class ExpectedDataflowProgress extends ArgumentMatcher { + @Nullable Integer counterCount; + @Nullable Integer metricCount; + @Nullable ApproximateProgress expectedProgress; + @Nullable Position expectedStopPosition; + + public ExpectedDataflowProgress withCounters(Integer counterCount) { + this.counterCount = counterCount; + return this; + } + + public ExpectedDataflowProgress withMetrics(Integer metricCount) { + this.metricCount = metricCount; + return this; + } + + public ExpectedDataflowProgress withProgress(ApproximateProgress expectedProgress) { + this.expectedProgress = expectedProgress; + return this; + } + + public ExpectedDataflowProgress withStopPosition(Position expectedStopPosition) { + this.expectedStopPosition = expectedStopPosition; + return this; + } + + @Override + public void describeTo(Description description) { + List values = new ArrayList<>(); + if (this.counterCount != null) { + for (int i = 0; i < counterCount; i++) { + values.add(extractCounter(makeCounter(i), false).toString()); + } + } + if (this.metricCount != null) { + for (int i = 0; i < metricCount; i++) { + values.add(extractCloudMetric(makeMetric(i), WORKER_ID).toString()); + } + } + if (this.expectedProgress != null) { + values.add("progress " + this.expectedProgress); + } + if (this.expectedStopPosition != null) { + values.add("stop position " + this.expectedStopPosition); + } else { + values.add("no stop position present"); + } + description.appendValueList("Dataflow progress with ", ", ", ".", values); + } + + @Override + public boolean matches(Object status) { + WorkItemStatus st = (WorkItemStatus) status; + return matchCountersAndMetrics(st) + && matchProgress(st) + && matchStopPosition(st); + } + + private boolean matchCountersAndMetrics(WorkItemStatus status) { + if (counterCount == null && metricCount == null) { + return true; + } + + List sentUpdates = status.getMetricUpdates(); + + if (counterCount + metricCount != sentUpdates.size()) { + return false; + } + + for (int i = 0; i < counterCount; i++) { + if (!sentUpdates.contains( + CounterTestUtils.extractCounterUpdate(makeCounter(i), false))) { + return false; + } + } + + for (int i = 0; i < metricCount; i++) { + if (!sentUpdates.contains(extractCloudMetric(makeMetric(i), WORKER_ID))) { + return false; + } + } + + return true; + } + + private boolean matchProgress(WorkItemStatus status) { + if (expectedProgress == null) { + return true; + } + ApproximateProgress progress = status.getProgress(); + return expectedProgress.equals(progress); + } + + private boolean matchStopPosition(WorkItemStatus status) { + Position actualStopPosition = status.getStopPosition(); + if (expectedStopPosition == null) { + return actualStopPosition == null; + } + return expectedStopPosition.equals(actualStopPosition); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerHarnessTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerHarnessTest.java new file mode 100644 index 000000000000..d1d369fe99ac --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerHarnessTest.java @@ -0,0 +1,243 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.Json; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.LeaseWorkItemRequest; +import com.google.api.services.dataflow.model.LeaseWorkItemResponse; +import com.google.api.services.dataflow.model.WorkItem; +import com.google.cloud.dataflow.sdk.options.DataflowWorkerHarnessOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.RestoreMappedDiagnosticContext; +import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.slf4j.MDC; + +import java.io.IOException; + +/** Unit tests for {@link DataflowWorkerHarness}. */ +@RunWith(JUnit4.class) +public class DataflowWorkerHarnessTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + @Rule public TestRule restoreMDC = new RestoreMappedDiagnosticContext(); + @Rule public ExpectedException expectedException = ExpectedException.none(); + @Mock private MockHttpTransport transport; + @Mock private MockLowLevelHttpRequest request; + @Mock private DataflowWorker mockDataflowWorker; + + private Dataflow service; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + when(transport.buildRequest(anyString(), anyString())).thenReturn(request); + doCallRealMethod().when(request).getContentAsString(); + + service = new Dataflow(transport, Transport.getJsonFactory(), null); + } + + @Test + public void testThatWeOnlyProcessWorkOnce() throws Exception { + when(mockDataflowWorker.getAndPerformWork()).thenReturn(true); + DataflowWorkerHarness.processWork(mockDataflowWorker); + verify(mockDataflowWorker).getAndPerformWork(); + verifyNoMoreInteractions(mockDataflowWorker); + } + + @Test + public void testThatWeOnlyProcessWorkOnceEvenWhenFailing() throws Exception { + when(mockDataflowWorker.getAndPerformWork()).thenReturn(false); + DataflowWorkerHarness.processWork(mockDataflowWorker); + verify(mockDataflowWorker).getAndPerformWork(); + verifyNoMoreInteractions(mockDataflowWorker); + } + + @Test + public void testCreationOfWorkerHarness() throws Exception { + System.getProperties().putAll(ImmutableMap + .builder() + .put("project_id", "projectId") + .put("job_id", "jobId") + .put("worker_id", "workerId") + .build()); + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + options.setGcpCredential(new TestCredential()); + assertNotNull(DataflowWorkerHarness.create(options)); + assertEquals("jobId", MDC.get("dataflow.jobId")); + assertEquals("workerId", MDC.get("dataflow.workerId")); + } + + @Test + public void testCloudServiceCall() throws Exception { + System.getProperties().putAll(ImmutableMap + .builder() + .put("project_id", "projectId") + .put("job_id", "jobId") + .put("worker_id", "workerId") + .build()); + WorkItem workItem = createWorkItem("projectId", "jobId"); + + when(request.execute()).thenReturn(generateMockResponse(workItem)); + + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + + DataflowWorker.WorkUnitClient client = + new DataflowWorkerHarness.DataflowWorkUnitClient(service, options); + + assertEquals(workItem, client.getWorkItem()); + + LeaseWorkItemRequest actualRequest = Transport.getJsonFactory().fromString( + request.getContentAsString(), LeaseWorkItemRequest.class); + assertEquals("workerId", actualRequest.getWorkerId()); + assertEquals(ImmutableList.of("workerId", "remote_source", "custom_source"), + actualRequest.getWorkerCapabilities()); + assertEquals(ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + actualRequest.getWorkItemTypes()); + assertEquals("1234", MDC.get("dataflow.workId")); + } + + @Test + public void testCloudServiceCallNoWorkId() throws Exception { + System.getProperties().putAll(ImmutableMap + .builder() + .put("project_id", "projectId") + .put("job_id", "jobId") + .put("worker_id", "workerId") + .build()); + + // If there's no work the service should return an empty work item. + WorkItem workItem = new WorkItem(); + + when(request.execute()).thenReturn(generateMockResponse(workItem)); + + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + + DataflowWorker.WorkUnitClient client = + new DataflowWorkerHarness.DataflowWorkUnitClient(service, options); + + assertNull(client.getWorkItem()); + + LeaseWorkItemRequest actualRequest = Transport.getJsonFactory().fromString( + request.getContentAsString(), LeaseWorkItemRequest.class); + assertEquals("workerId", actualRequest.getWorkerId()); + assertEquals(ImmutableList.of("workerId", "remote_source", "custom_source"), + actualRequest.getWorkerCapabilities()); + assertEquals(ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + actualRequest.getWorkItemTypes()); + } + + @Test + public void testCloudServiceCallNoWorkItem() throws Exception { + System.getProperties().putAll(ImmutableMap + .builder() + .put("project_id", "projectId") + .put("job_id", "jobId") + .put("worker_id", "workerId") + .build()); + + when(request.execute()).thenReturn(generateMockResponse()); + + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + + DataflowWorker.WorkUnitClient client = + new DataflowWorkerHarness.DataflowWorkUnitClient(service, options); + + assertNull(client.getWorkItem()); + + LeaseWorkItemRequest actualRequest = Transport.getJsonFactory().fromString( + request.getContentAsString(), LeaseWorkItemRequest.class); + assertEquals("workerId", actualRequest.getWorkerId()); + assertEquals(ImmutableList.of("workerId", "remote_source", "custom_source"), + actualRequest.getWorkerCapabilities()); + assertEquals(ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + actualRequest.getWorkItemTypes()); + } + + @Test + public void testCloudServiceCallMultipleWorkItems() throws Exception { + expectedException.expect(IOException.class); + expectedException.expectMessage( + "This version of the SDK expects no more than one work item from the service"); + System.getProperties().putAll(ImmutableMap + .builder() + .put("project_id", "projectId") + .put("job_id", "jobId") + .put("worker_id", "workerId") + .build()); + + WorkItem workItem1 = createWorkItem("projectId", "jobId"); + WorkItem workItem2 = createWorkItem("projectId", "jobId"); + + when(request.execute()).thenReturn(generateMockResponse(workItem1, workItem2)); + + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + + DataflowWorker.WorkUnitClient client = + new DataflowWorkerHarness.DataflowWorkUnitClient(service, options); + + client.getWorkItem(); + } + + private LowLevelHttpResponse generateMockResponse(WorkItem ... workItems) throws Exception { + MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); + response.setContentType(Json.MEDIA_TYPE); + LeaseWorkItemResponse lease = new LeaseWorkItemResponse(); + lease.setWorkItems(Lists.newArrayList(workItems)); + // N.B. Setting the factory is necessary in order to get valid JSON. + lease.setFactory(Transport.getJsonFactory()); + response.setContent(lease.toPrettyString()); + return response; + } + + private WorkItem createWorkItem(String projectId, String jobId) { + WorkItem workItem = new WorkItem(); + workItem.setFactory(Transport.getJsonFactory()); + workItem.setProjectId(projectId); + workItem.setJobId(jobId); + + // We need to set a work id because otherwise the client will treat the response as + // indicating no work is available. + workItem.setId(1234L); + return workItem; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerTest.java new file mode 100644 index 000000000000..2d51fb283895 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/DataflowWorkerTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static org.junit.Assert.assertFalse; +import static org.mockito.Matchers.argThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.model.WorkItem; +import com.google.api.services.dataflow.model.WorkItemStatus; +import com.google.cloud.dataflow.sdk.options.DataflowWorkerHarnessOptions; +import com.google.cloud.dataflow.sdk.testing.FastNanoClockAndSleeper; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Unit tests for {@link DataflowWorker}. */ +@RunWith(JUnit4.class) +public class DataflowWorkerTest { + @Rule + public FastNanoClockAndSleeper clockAndSleeper = new FastNanoClockAndSleeper(); + + @Mock + DataflowWorker.WorkUnitClient mockWorkUnitClient; + + @Mock + DataflowWorkerHarnessOptions options; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testWhenNoWorkThatWeReturnFalse() throws Exception { + DataflowWorker worker = new DataflowWorker(mockWorkUnitClient, options); + when(mockWorkUnitClient.getWorkItem()).thenReturn(null); + + assertFalse(worker.getAndPerformWork()); + } + + @Test + public void testWhenProcessingWorkUnitFailsWeReportStatus() throws Exception { + DataflowWorker worker = new DataflowWorker(mockWorkUnitClient, options); + when(mockWorkUnitClient.getWorkItem()).thenReturn(new WorkItem().setId(1L)).thenReturn(null); + + assertFalse(worker.getAndPerformWork()); + verify(mockWorkUnitClient).reportWorkItemStatus(argThat(cloudWorkHasErrors())); + } + + private Matcher cloudWorkHasErrors() { + return new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText("WorkItemStatus expected to have errors"); + } + + @Override + protected boolean matchesSafely(WorkItemStatus status) { + return status.getCompleted() && !status.getErrors().isEmpty(); + } + }; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSourceTest.java new file mode 100644 index 000000000000..b41bd1b2e291 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/GroupingShuffleSourceTest.java @@ -0,0 +1,499 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.encodeBase64URLSafeString; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudDuration; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.api.services.dataflow.model.Position; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Reiterable; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.util.common.worker.Source.SourceIterator; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.Lists; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Tests for GroupingShuffleSource. + */ +@RunWith(JUnit4.class) +public class GroupingShuffleSourceTest { + static final List>> NO_KVS = Collections.emptyList(); + + static final Instant timestamp = new Instant(123000); + static final IntervalWindow window = new IntervalWindow(timestamp, timestamp.plus(1000)); + + static final List>> KVS = Arrays.asList( + KV.of(1, Arrays.asList("in 1a", "in 1b")), + KV.of(2, Arrays.asList("in 2a", "in 2b")), + KV.of(3, Arrays.asList("in 3")), + KV.of(4, Arrays.asList("in 4a", "in 4b", "in 4c", "in 4d")), + KV.of(5, Arrays.asList("in 5"))); + + /** How many of the values with each key are to be read. */ + enum ValuesToRead { + /** Don't even ask for the values iterator. */ + SKIP_VALUES, + /** Get the iterator, but don't read any values. */ + READ_NO_VALUES, + /** Read just the first value. */ + READ_ONE_VALUE, + /** Read all the values. */ + READ_ALL_VALUES + } + + void runTestReadShuffleSource(List>> input, + ValuesToRead valuesToRead) + throws Exception { + Coder> elemCoder = + WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder()); + BatchModeExecutionContext context = new BatchModeExecutionContext(); + GroupingShuffleSource> shuffleSource = + new GroupingShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, null, + WindowedValue.getFullCoder( + KvCoder.of( + BigEndianIntegerCoder.of(), + IterableCoder.of( + WindowedValue.getFullCoder(StringUtf8Coder.of(), + IntervalWindow.getCoder()))), + IntervalWindow.getCoder()), + context); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(shuffleSource); + + TestShuffleReader shuffleReader = new TestShuffleReader(); + List expectedSizes = new ArrayList<>(); + for (KV> kvs : input) { + Integer key = kvs.getKey(); + byte[] keyByte = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), key); + + for (String value : kvs.getValue()) { + byte[] valueByte = CoderUtils.encodeToByteArray( + elemCoder, WindowedValue.of(value, timestamp, Lists.newArrayList(window))); + byte[] skey = CoderUtils.encodeToByteArray(InstantCoder.of(), timestamp); + ShuffleEntry shuffleEntry = new ShuffleEntry(keyByte, skey, valueByte); + shuffleReader.addEntry(shuffleEntry); + expectedSizes.add(shuffleEntry.length()); + } + } + + List>>> actual = new ArrayList<>(); + try (SourceIterator>>>> iter = + shuffleSource.iterator(shuffleReader)) { + Iterable> prevValuesIterable = null; + Iterator> prevValuesIterator = null; + while (iter.hasNext()) { + Assert.assertTrue(iter.hasNext()); + Assert.assertTrue(iter.hasNext()); + + KV>> elem = iter.next().getValue(); + Integer key = elem.getKey(); + List> values = new ArrayList<>(); + if (valuesToRead.ordinal() > ValuesToRead.SKIP_VALUES.ordinal()) { + if (prevValuesIterable != null) { + prevValuesIterable.iterator(); // Verifies that this does not throw. + } + if (prevValuesIterator != null) { + prevValuesIterator.hasNext(); // Verifies that this does not throw. + } + + Iterable> valuesIterable = elem.getValue(); + Iterator> valuesIterator = valuesIterable.iterator(); + + if (valuesToRead.ordinal() >= ValuesToRead.READ_ONE_VALUE.ordinal()) { + while (valuesIterator.hasNext()) { + Assert.assertTrue(valuesIterator.hasNext()); + Assert.assertTrue(valuesIterator.hasNext()); + Assert.assertEquals("BatchModeExecutionContext key", + key, context.getKey()); + values.add(valuesIterator.next()); + if (valuesToRead == ValuesToRead.READ_ONE_VALUE) { + break; + } + } + if (valuesToRead == ValuesToRead.READ_ALL_VALUES) { + Assert.assertFalse(valuesIterator.hasNext()); + Assert.assertFalse(valuesIterator.hasNext()); + + try { + valuesIterator.next(); + Assert.fail("Expected NoSuchElementException"); + } catch (NoSuchElementException exn) { + // As expected. + } + valuesIterable.iterator(); // Verifies that this does not throw. + } + } + + prevValuesIterable = valuesIterable; + prevValuesIterator = valuesIterator; + } + + actual.add(KV.of(key, values)); + } + Assert.assertFalse(iter.hasNext()); + Assert.assertFalse(iter.hasNext()); + try { + iter.next(); + Assert.fail("Expected NoSuchElementException"); + } catch (NoSuchElementException exn) { + // As expected. + } + } + + List>>> expected = new ArrayList<>(); + for (KV> kvs : input) { + Integer key = kvs.getKey(); + List> values = new ArrayList<>(); + if (valuesToRead.ordinal() >= ValuesToRead.READ_ONE_VALUE.ordinal()) { + for (String value : kvs.getValue()) { + values.add(WindowedValue.of(value, timestamp, Lists.newArrayList(window))); + if (valuesToRead == ValuesToRead.READ_ONE_VALUE) { + break; + } + } + } + expected.add(KV.of(key, values)); + } + Assert.assertEquals(expected, actual); + Assert.assertEquals(expectedSizes, observer.getActualSizes()); + } + + @Test + public void testReadEmptyShuffleSource() throws Exception { + runTestReadShuffleSource(NO_KVS, ValuesToRead.READ_ALL_VALUES); + } + + @Test + public void testReadEmptyShuffleSourceSkippingValues() throws Exception { + runTestReadShuffleSource(NO_KVS, ValuesToRead.SKIP_VALUES); + } + + @Test + public void testReadNonEmptyShuffleSource() throws Exception { + runTestReadShuffleSource(KVS, ValuesToRead.READ_ALL_VALUES); + } + + @Test + public void testReadNonEmptyShuffleSourceReadingOneValue() throws Exception { + runTestReadShuffleSource(KVS, ValuesToRead.READ_ONE_VALUE); + } + + @Test + public void testReadNonEmptyShuffleSourceReadingNoValues() throws Exception { + runTestReadShuffleSource(KVS, ValuesToRead.READ_NO_VALUES); + } + + @Test + public void testReadNonEmptyShuffleSourceSkippingValues() throws Exception { + runTestReadShuffleSource(KVS, ValuesToRead.SKIP_VALUES); + } + + static byte[] fabricatePosition(int shard, byte[] key) throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.writeInt(shard); + if (key != null) { + dos.writeInt(Arrays.hashCode(key)); + } + return os.toByteArray(); + } + + @Test + public void testReadFromEmptyShuffleSourceAndUpdateStopPosition() + throws Exception { + BatchModeExecutionContext context = new BatchModeExecutionContext(); + GroupingShuffleSource shuffleSource = + new GroupingShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, null, + WindowedValue.getFullCoder( + KvCoder.of( + BigEndianIntegerCoder.of(), + IterableCoder.of(BigEndianIntegerCoder.of())), + IntervalWindow.getCoder()), + context); + TestShuffleReader shuffleReader = new TestShuffleReader(); + try (Source.SourceIterator>>> iter = + shuffleSource.iterator(shuffleReader)) { + + Position proposedStopPosition = new Position(); + String stop = encodeBase64URLSafeString(fabricatePosition(0, null)); + proposedStopPosition.setShufflePosition(stop); + + // Cannot update stop position since all input was consumed. + Assert.assertEquals(null, iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + } + } + + @Test + public void testReadFromShuffleSourceAndFailToUpdateStopPosition() + throws Exception { + BatchModeExecutionContext context = new BatchModeExecutionContext(); + final int kFirstShard = 0; + + TestShuffleReader shuffleReader = new TestShuffleReader(); + final int kNumRecords = 2; + for (int i = 0; i < kNumRecords; ++i) { + byte[] key = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i); + shuffleReader.addEntry(new ShuffleEntry( + fabricatePosition(kFirstShard, key), key, null, key)); + } + + // Note that TestShuffleReader start/end positions are in the + // space of keys not the positions (TODO: should probably always + // use positions instead). + String stop = encodeBase64URLSafeString( + fabricatePosition(kNumRecords, null)); + GroupingShuffleSource shuffleSource = + new GroupingShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, stop, + WindowedValue.getFullCoder( + KvCoder.of( + BigEndianIntegerCoder.of(), + IterableCoder.of(BigEndianIntegerCoder.of())), + IntervalWindow.getCoder()), + context); + + try (Source.SourceIterator>>> iter = + shuffleSource.iterator(shuffleReader)) { + + Position proposedStopPosition = new Position(); + proposedStopPosition.setShufflePosition( + encodeBase64URLSafeString(fabricatePosition(kNumRecords + 1, null))); + + // Cannot update the stop position since the value provided is + // past the current stop position. + Assert.assertEquals(null, iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + + int i = 0; + for (; iter.hasNext(); ++i) { + KV> elem = iter.next().getValue(); + if (i == 0) { + // First record + byte[] key = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i); + proposedStopPosition.setShufflePosition( + encodeBase64URLSafeString(fabricatePosition(kFirstShard, key))); + // Cannot update stop position since it is identical with + // the position of the record that was just returned. + Assert.assertEquals(null, iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + + proposedStopPosition.setShufflePosition( + encodeBase64URLSafeString(fabricatePosition(kFirstShard, null))); + // Cannot update stop position since it comes before current position + Assert.assertEquals(null, iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + } + } + Assert.assertEquals(kNumRecords, i); + + proposedStopPosition.setShufflePosition( + encodeBase64URLSafeString(fabricatePosition(kFirstShard, null))); + // Cannot update stop position since all input was consumed. + Assert.assertEquals(null, iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + } + } + + @Test + public void testReadFromShuffleSourceAndUpdateStopPosition() + throws Exception { + BatchModeExecutionContext context = new BatchModeExecutionContext(); + GroupingShuffleSource shuffleSource = + new GroupingShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, null, + WindowedValue.getFullCoder( + KvCoder.of( + BigEndianIntegerCoder.of(), + IterableCoder.of(BigEndianIntegerCoder.of())), + IntervalWindow.getCoder()), + context); + + TestShuffleReader shuffleReader = new TestShuffleReader(); + final int kNumRecords = 10; + final int kFirstShard = 0; + final int kSecondShard = 1; + + // Setting up two shards with kNumRecords each; keys are unique + // (hence groups of values for the same key are singletons) + // therefore each record comes with a unique position constructed. + for (int i = 0; i < kNumRecords; ++i) { + byte[] keyByte = CoderUtils.encodeToByteArray( + BigEndianIntegerCoder.of(), i); + ShuffleEntry entry = new ShuffleEntry( + fabricatePosition(kFirstShard, keyByte), keyByte, null, keyByte); + shuffleReader.addEntry(entry); + } + + for (int i = kNumRecords; i < 2 * kNumRecords; ++i) { + byte[] keyByte = CoderUtils.encodeToByteArray( + BigEndianIntegerCoder.of(), i); + + ShuffleEntry entry = new ShuffleEntry( + fabricatePosition(kSecondShard, keyByte), keyByte, null, keyByte); + shuffleReader.addEntry(entry); + } + + int i = 0; + try (Source.SourceIterator>>> iter = + shuffleSource.iterator(shuffleReader)) { + + Position proposedStopPosition = new Position(); + + Assert.assertNull(iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + + // Stop at the shard boundary + String stop = encodeBase64URLSafeString(fabricatePosition(kSecondShard, null)); + proposedStopPosition.setShufflePosition(stop); + + Assert.assertEquals( + stop, + sourcePositionToCloudPosition( + iter.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))) + .getShufflePosition()); + + while (iter.hasNext()) { + Assert.assertTrue(iter.hasNext()); + Assert.assertTrue(iter.hasNext()); + + KV> elem = iter.next().getValue(); + int key = elem.getKey(); + Assert.assertEquals(key, i); + + Iterable valuesIterable = elem.getValue(); + Iterator valuesIterator = valuesIterable.iterator(); + + int j = 0; + while (valuesIterator.hasNext()) { + Assert.assertTrue(valuesIterator.hasNext()); + Assert.assertTrue(valuesIterator.hasNext()); + + int value = valuesIterator.next(); + Assert.assertEquals(value, i); + ++j; + } + Assert.assertEquals(j, 1); + ++i; + } + + ApproximateProgress progress = + sourceProgressToCloudProgress(iter.getProgress()); + Assert.assertEquals(stop, progress.getPosition().getShufflePosition()); + } + Assert.assertEquals(i, kNumRecords); + } + + @Test + public void testGetApproximateProgress() throws Exception { + // Store the positions of all KVs returned. + List positionsList = new ArrayList(); + + BatchModeExecutionContext context = new BatchModeExecutionContext(); + GroupingShuffleSource shuffleSource = + new GroupingShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, null, + WindowedValue.getFullCoder( + KvCoder.of( + BigEndianIntegerCoder.of(), + IterableCoder.of(BigEndianIntegerCoder.of())), + IntervalWindow.getCoder()), + context); + + TestShuffleReader shuffleReader = new TestShuffleReader(); + final int kNumRecords = 10; + + for (int i = 0; i < kNumRecords; ++i) { + byte[] position = fabricatePosition(i, null); + byte[] keyByte = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i); + positionsList.add(position); + ShuffleEntry entry = new ShuffleEntry(position, keyByte, null, keyByte); + shuffleReader.addEntry(entry); + } + + try (Source.SourceIterator>>> sourceIterator = + shuffleSource.iterator(shuffleReader)) { + Integer i = 0; + while (sourceIterator.hasNext()) { + Assert.assertTrue(sourceIterator.hasNext()); + ApproximateProgress progress = sourceProgressToCloudProgress(sourceIterator.getProgress()); + Assert.assertNotNull(progress.getPosition().getShufflePosition()); + + // Compare returned position with the expected position. + Assert.assertEquals(ByteArrayShufflePosition.of(positionsList.get(i)).encodeBase64(), + progress.getPosition().getShufflePosition()); + + WindowedValue>> elem = sourceIterator.next(); + Assert.assertEquals(i, elem.getValue().getKey()); + i++; + } + Assert.assertFalse(sourceIterator.hasNext()); + + ApproximateProgress finalProgress = + sourceProgressToCloudProgress(sourceIterator.getProgress()); + Assert.assertEquals(1.0, + (float) finalProgress.getPercentComplete(), 0.000000001); + Assert.assertEquals(Duration.ZERO, fromCloudDuration(finalProgress.getRemainingTime())); + } + } + + private ApproximateProgress createApproximateProgress( + com.google.api.services.dataflow.model.Position position) { + return new ApproximateProgress().setPosition(position); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceFactoryTest.java new file mode 100644 index 000000000000..64cf4f552021 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceFactoryTest.java @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.InMemorySourceTest.encodedElements; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addStringList; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for InMemorySourceFactory. + */ +@RunWith(JUnit4.class) +public class InMemorySourceFactoryTest { + static com.google.api.services.dataflow.model.Source createInMemoryCloudSource( + List elements, + Long start, + Long end, + Coder coder) + throws Exception { + List encodedElements = encodedElements(elements, coder); + + CloudObject spec = CloudObject.forClassName("InMemorySource"); + addStringList(spec, PropertyNames.ELEMENTS, encodedElements); + + if (start != null) { + addLong(spec, PropertyNames.START_INDEX, start); + } + if (end != null) { + addLong(spec, PropertyNames.END_INDEX, end); + } + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(coder.asCloudObject()); + + return cloudSource; + } + + void runTestCreateInMemorySource(List elements, + Long start, + Long end, + int expectedStart, + int expectedEnd, + Coder coder) + throws Exception { + com.google.api.services.dataflow.model.Source cloudSource = + createInMemoryCloudSource(elements, start, end, coder); + + Source source = SourceFactory.create(PipelineOptionsFactory.create(), cloudSource, + new BatchModeExecutionContext()); + Assert.assertThat(source, new IsInstanceOf(InMemorySource.class)); + InMemorySource inMemorySource = (InMemorySource) source; + Assert.assertEquals(encodedElements(elements, coder), + inMemorySource.encodedElements); + Assert.assertEquals(expectedStart, inMemorySource.startIndex); + Assert.assertEquals(expectedEnd, inMemorySource.endIndex); + Assert.assertEquals(coder, inMemorySource.coder); + } + + @Test + public void testCreatePlainInMemorySource() throws Exception { + runTestCreateInMemorySource( + Arrays.asList("hi", "there", "bob"), + null, null, + 0, 3, + StringUtf8Coder.of()); + } + + @Test + public void testCreateRichInMemorySource() throws Exception { + runTestCreateInMemorySource( + Arrays.asList(33, 44, 55, 66, 77, 88), + 1L, 3L, + 1, 3, + BigEndianIntegerCoder.of()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceTest.java new file mode 100644 index 000000000000..d7574c517b4e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/InMemorySourceTest.java @@ -0,0 +1,236 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.api.services.dataflow.model.Position; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for InMemorySource. + */ +@RunWith(JUnit4.class) +public class InMemorySourceTest { + static List encodedElements(List elements, Coder coder) + throws Exception { + List encodedElements = new ArrayList<>(); + for (T element : elements) { + byte[] encodedElement = encodeToByteArray(coder, element); + String encodedElementString = byteArrayToJsonString(encodedElement); + encodedElements.add(encodedElementString); + } + return encodedElements; + } + + void runTestReadInMemorySource(List elements, + Long startIndex, + Long endIndex, + List expectedElements, + List expectedSizes, + Coder coder) + throws Exception { + InMemorySource inMemorySource = new InMemorySource<>( + encodedElements(elements, coder), startIndex, endIndex, coder); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(inMemorySource); + List actualElements = new ArrayList<>(); + try (Source.SourceIterator iterator = inMemorySource.iterator()) { + for (long i = inMemorySource.startIndex; iterator.hasNext(); i++) { + Assert.assertEquals( + new ApproximateProgress().setPosition(makeIndexPosition(i)), + sourceProgressToCloudProgress(iterator.getProgress())); + actualElements.add(iterator.next()); + } + } + Assert.assertEquals(expectedElements, actualElements); + Assert.assertEquals(expectedSizes, observer.getActualSizes()); + } + + @Test + public void testReadAllElements() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + null, + null, + Arrays.asList(33, 44, 55, 66, 77, 88), + Arrays.asList(4, 4, 4, 4, 4, 4), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadElementsFromStart() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + 2L, + null, + Arrays.asList(55, 66, 77, 88), + Arrays.asList(4, 4, 4, 4), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadElementsToEnd() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + null, + 3L, + Arrays.asList(33, 44, 55), + Arrays.asList(4, 4, 4), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadElementsFromStartToEnd() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + 2L, + 5L, + Arrays.asList(55, 66, 77), + Arrays.asList(4, 4, 4), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadElementsOffEnd() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + null, + 30L, + Arrays.asList(33, 44, 55, 66, 77, 88), + Arrays.asList(4, 4, 4, 4, 4, 4), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadElementsFromStartPastEnd() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + 20L, + null, + Arrays.asList(), + Arrays.asList(), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadElementsFromStartToEndEmptyRange() throws Exception { + runTestReadInMemorySource(Arrays.asList(33, 44, 55, 66, 77, 88), + 2L, + 2L, + Arrays.asList(), + Arrays.asList(), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadNoElements() throws Exception { + runTestReadInMemorySource(Arrays.asList(), + null, + null, + Arrays.asList(), + Arrays.asList(), + BigEndianIntegerCoder.of()); + } + + @Test + public void testReadNoElementsFromStartToEndEmptyRange() throws Exception { + runTestReadInMemorySource(Arrays.asList(), + 0L, + 0L, + Arrays.asList(), + Arrays.asList(), + BigEndianIntegerCoder.of()); + } + + @Test + public void testUpdatePosition() throws Exception { + List elements = Arrays.asList(33, 44, 55, 66, 77, 88); + final long start = 1L; + final long stop = 3L; + final long end = 4L; + + Coder coder = BigEndianIntegerCoder.of(); + InMemorySource inMemorySource = new InMemorySource<>( + encodedElements(elements, coder), start, end, coder); + + // Illegal proposed stop position. + try (Source.SourceIterator iterator = inMemorySource.iterator()) { + Assert.assertNull(iterator.updateStopPosition( + cloudProgressToSourceProgress(new ApproximateProgress()))); + Assert.assertNull(iterator.updateStopPosition( + cloudProgressToSourceProgress( + new ApproximateProgress().setPosition(makeIndexPosition(null))))); + } + + // Successful update. + try (InMemorySource.InMemorySourceIterator iterator = + (InMemorySource.InMemorySourceIterator) inMemorySource.iterator()) { + Assert.assertEquals( + makeIndexPosition(stop), + sourcePositionToCloudPosition( + iterator.updateStopPosition( + cloudProgressToSourceProgress( + new ApproximateProgress().setPosition(makeIndexPosition(stop)))))); + Assert.assertEquals(stop, iterator.endPosition); + Assert.assertEquals(44, iterator.next().intValue()); + Assert.assertEquals(55, iterator.next().intValue()); + Assert.assertFalse(iterator.hasNext()); + } + + // Proposed stop position is before the current position, no update. + try (InMemorySource.InMemorySourceIterator iterator = + (InMemorySource.InMemorySourceIterator) inMemorySource.iterator()) { + Assert.assertEquals(44, iterator.next().intValue()); + Assert.assertEquals(55, iterator.next().intValue()); + Assert.assertNull(iterator.updateStopPosition( + cloudProgressToSourceProgress( + new ApproximateProgress().setPosition(makeIndexPosition(stop))))); + Assert.assertEquals((int) end, iterator.endPosition); + Assert.assertTrue(iterator.hasNext()); + } + + // Proposed stop position is after the current stop (end) position, no update. + try (InMemorySource.InMemorySourceIterator iterator = + (InMemorySource.InMemorySourceIterator) inMemorySource.iterator()) { + Assert.assertNull( + iterator.updateStopPosition( + cloudProgressToSourceProgress( + new ApproximateProgress().setPosition(makeIndexPosition(end + 1))))); + Assert.assertEquals((int) end, iterator.endPosition); + } + } + + private Position makeIndexPosition(Long index) { + Position position = new Position(); + if (index != null) { + position.setRecordIndex(index); + } + return position; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/MapTaskExecutorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/MapTaskExecutorFactoryTest.java new file mode 100644 index 000000000000..fae22797ef89 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/MapTaskExecutorFactoryTest.java @@ -0,0 +1,567 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; + +import com.google.api.services.dataflow.model.FlattenInstruction; +import com.google.api.services.dataflow.model.InstructionInput; +import com.google.api.services.dataflow.model.InstructionOutput; +import com.google.api.services.dataflow.model.MapTask; +import com.google.api.services.dataflow.model.ParDoInstruction; +import com.google.api.services.dataflow.model.ParallelInstruction; +import com.google.api.services.dataflow.model.PartialGroupByKeyInstruction; +import com.google.api.services.dataflow.model.ReadInstruction; +import com.google.api.services.dataflow.model.WriteInstruction; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.worker.SinkFactoryTest.TestSink; +import com.google.cloud.dataflow.sdk.runners.worker.SinkFactoryTest.TestSinkFactory; +import com.google.cloud.dataflow.sdk.runners.worker.SourceFactoryTest.TestSource; +import com.google.cloud.dataflow.sdk.runners.worker.SourceFactoryTest.TestSourceFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.FlattenOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.MapTaskExecutor; +import com.google.cloud.dataflow.sdk.util.common.worker.Operation; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.ReadOperation; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.util.common.worker.WriteOperation; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for MapTaskExecutorFactory. + */ +@RunWith(JUnit4.class) +public class MapTaskExecutorFactoryTest { + @Test + public void testCreateMapTaskExecutor() throws Exception { + List instructions = + Arrays.asList( + createReadInstruction("Read"), + createParDoInstruction(0, 0, "DoFn1"), + createParDoInstruction(0, 0, "DoFn2"), + createFlattenInstruction(1, 0, 2, 0, "Flatten"), + createWriteInstruction(3, 0, "Write")); + + MapTask mapTask = new MapTask(); + mapTask.setStageName("test"); + mapTask.setInstructions(instructions); + + CounterSet counterSet = null; + try (MapTaskExecutor executor = + MapTaskExecutorFactory.create( + PipelineOptionsFactory.create(), + mapTask, + new BatchModeExecutionContext())) { + + @SuppressWarnings("unchecked") + List operations = (List) executor.operations; + assertThat( + operations, + CoreMatchers.hasItems( + new IsInstanceOf(ReadOperation.class), + new IsInstanceOf(ParDoOperation.class), + new IsInstanceOf(ParDoOperation.class), + new IsInstanceOf(FlattenOperation.class), + new IsInstanceOf(WriteOperation.class))); + counterSet = executor.getOutputCounters(); + } + + assertEquals( + new CounterSet( + Counter.longs("read_output_name-ElementCount", SUM) + .resetToValue(0L), + Counter.longs("read_output_name-MeanByteCount", MEAN) + .resetToValue(0, 0L), + Counter.longs("Read-ByteCount", SUM).resetToValue(0L), + Counter.longs("test-Read-start-msecs", SUM) + .resetToValue(0L), + Counter.longs("test-Read-read-msecs", SUM) + .resetToValue(0L), + Counter.longs("test-Read-process-msecs", SUM) + .resetToValue(0L), + Counter.longs("test-Read-finish-msecs", SUM) + .resetToValue(0L), + Counter.longs("DoFn1_output-ElementCount", SUM) + .resetToValue(0L), + Counter.longs("DoFn1_output-MeanByteCount", MEAN) + .resetToValue(0, 0L), + Counter.longs("test-DoFn1-start-msecs", SUM).resetToValue(0L), + Counter.longs("test-DoFn1-process-msecs", SUM).resetToValue(0L), + Counter.longs("test-DoFn1-finish-msecs", SUM).resetToValue(0L), + Counter.longs("DoFn2_output-ElementCount", SUM) + .resetToValue(0L), + Counter.longs("DoFn2_output-MeanByteCount", MEAN) + .resetToValue(0, 0L), + Counter.longs("test-DoFn2-start-msecs", SUM).resetToValue(0L), + Counter.longs("test-DoFn2-process-msecs", SUM).resetToValue(0L), + Counter.longs("test-DoFn2-finish-msecs", SUM).resetToValue(0L), + Counter.longs("flatten_output_name-ElementCount", SUM) + .resetToValue(0L), + Counter.longs("flatten_output_name-MeanByteCount", MEAN) + .resetToValue(0, 0L), + Counter.longs("test-Flatten-start-msecs", SUM).resetToValue(0L), + Counter.longs("test-Flatten-process-msecs", SUM).resetToValue(0L), + Counter.longs("test-Flatten-finish-msecs", SUM).resetToValue(0L), + Counter.longs("Write-ByteCount", SUM) + .resetToValue(0L), + Counter.longs("test-Write-start-msecs", SUM).resetToValue(0L), + Counter.longs("test-Write-process-msecs", SUM).resetToValue(0L), + Counter.longs("test-Write-finish-msecs", SUM).resetToValue(0L), + Counter.longs("test-other-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-other-msecs")).getAggregate(false))), + counterSet); + } + + @Test + public void testExecutionContextPlumbing() throws Exception { + List instructions = + Arrays.asList( + createReadInstruction("Read"), + createParDoInstruction(0, 0, "DoFn1"), + createParDoInstruction(1, 0, "DoFn2"), + createWriteInstruction(2, 0, "Write")); + + MapTask mapTask = new MapTask(); + mapTask.setInstructions(instructions); + + BatchModeExecutionContext context = new BatchModeExecutionContext(); + + try (MapTaskExecutor executor = + MapTaskExecutorFactory.create( + PipelineOptionsFactory.create(), mapTask, context)) { + executor.execute(); + } + + List stepNames = new ArrayList<>(); + for (ExecutionContext.StepContext stepContext + : context.getAllStepContexts()) { + stepNames.add(stepContext.getStepName()); + } + assertThat(stepNames, CoreMatchers.hasItems("DoFn1", "DoFn2")); + } + + static ParallelInstruction createReadInstruction(String name) { + CloudObject spec = CloudObject.forClass(TestSourceFactory.class); + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(CloudObject.forClass(StringUtf8Coder.class)); + + ReadInstruction readInstruction = new ReadInstruction(); + readInstruction.setSource(cloudSource); + + InstructionOutput output = new InstructionOutput(); + output.setName("read_output_name"); + output.setCodec(CloudObject.forClass(StringUtf8Coder.class)); + + ParallelInstruction instruction = new ParallelInstruction(); + instruction.setSystemName(name); + instruction.setRead(readInstruction); + instruction.setOutputs(Arrays.asList(output)); + + return instruction; + } + + @Test + public void testCreateReadOperation() throws Exception { + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, + counterSet.getAddCounterMutator()); + Operation operation = MapTaskExecutorFactory.createOperation( + PipelineOptionsFactory.create(), + createReadInstruction("Read"), + new BatchModeExecutionContext(), + Collections.emptyList(), + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + assertThat(operation, new IsInstanceOf(ReadOperation.class)); + ReadOperation readOperation = (ReadOperation) operation; + + assertEquals(readOperation.receivers.length, 1); + assertEquals(readOperation.receivers[0].getReceiverCount(), 0); + assertEquals(readOperation.initializationState, + Operation.InitializationState.UNSTARTED); + assertThat(readOperation.source, new IsInstanceOf(TestSource.class)); + + assertEquals( + new CounterSet( + Counter.longs("test-Read-start-msecs", SUM) + .resetToValue(0L), + Counter.longs("read_output_name-MeanByteCount", MEAN) + .resetToValue(0, 0L), + Counter.longs("Read-ByteCount", SUM).resetToValue(0L), + Counter.longs("test-Read-finish-msecs", SUM) + .resetToValue(0L), + Counter.longs("test-Read-read-msecs", SUM), + Counter.longs("test-Read-process-msecs", SUM), + Counter.longs("read_output_name-ElementCount", SUM).resetToValue(0L)), + counterSet); + } + + static ParallelInstruction createWriteInstruction( + int producerIndex, + int producerOutputNum, + String systemName) { + InstructionInput cloudInput = new InstructionInput(); + cloudInput.setProducerInstructionIndex(producerIndex); + cloudInput.setOutputNum(producerOutputNum); + + CloudObject spec = CloudObject.forClass(TestSinkFactory.class); + + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(CloudObject.forClass(StringUtf8Coder.class)); + + WriteInstruction writeInstruction = new WriteInstruction(); + writeInstruction.setInput(cloudInput); + writeInstruction.setSink(cloudSink); + + ParallelInstruction instruction = new ParallelInstruction(); + instruction.setWrite(writeInstruction); + instruction.setSystemName(systemName); + + return instruction; + } + + @Test + public void testCreateWriteOperation() throws Exception { + List priorOperations = Arrays.asList(new Operation[]{ + new TestOperation(3), + new TestOperation(5), + new TestOperation(1) }); + + int producerIndex = 1; + int producerOutputNum = 2; + + ParallelInstruction instruction = + createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"); + + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, + counterSet.getAddCounterMutator()); + Operation operation = MapTaskExecutorFactory.createOperation( + PipelineOptionsFactory.create(), + instruction, + new BatchModeExecutionContext(), + priorOperations, + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + assertThat(operation, new IsInstanceOf(WriteOperation.class)); + WriteOperation writeOperation = (WriteOperation) operation; + + assertEquals(writeOperation.receivers.length, 0); + assertEquals(writeOperation.initializationState, + Operation.InitializationState.UNSTARTED); + assertThat(writeOperation.sink, + new IsInstanceOf(TestSink.class)); + + assertSame( + writeOperation, + priorOperations.get(producerIndex).receivers[producerOutputNum] + .getOnlyReceiver()); + + assertEquals( + new CounterSet( + Counter.longs("WriteOperation-ByteCount", SUM) + .resetToValue(0L), + Counter.longs("test-WriteOperation-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-WriteOperation-start-msecs")).getAggregate(false)), + Counter.longs("test-WriteOperation-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-WriteOperation-process-msecs")).getAggregate(false)), + Counter.longs("test-WriteOperation-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-WriteOperation-finish-msecs")).getAggregate(false))), + counterSet); + } + + static class TestDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { } + } + + static ParallelInstruction createParDoInstruction( + int producerIndex, + int producerOutputNum, + String systemName) { + InstructionInput cloudInput = new InstructionInput(); + cloudInput.setProducerInstructionIndex(producerIndex); + cloudInput.setOutputNum(producerOutputNum); + + TestDoFn fn = new TestDoFn(); + + String serializedFn = + StringUtils.byteArrayToJsonString( + SerializableUtils.serializeToByteArray(fn)); + + CloudObject cloudUserFn = CloudObject.forClassName("DoFn"); + addString(cloudUserFn, PropertyNames.SERIALIZED_FN, serializedFn); + + ParDoInstruction parDoInstruction = new ParDoInstruction(); + parDoInstruction.setInput(cloudInput); + parDoInstruction.setNumOutputs(1); + parDoInstruction.setUserFn(cloudUserFn); + + InstructionOutput output = new InstructionOutput(); + output.setName(systemName + "_output"); + output.setCodec(CloudObject.forClass(StringUtf8Coder.class)); + + ParallelInstruction instruction = new ParallelInstruction(); + instruction.setParDo(parDoInstruction); + instruction.setOutputs(Arrays.asList(output)); + instruction.setSystemName(systemName); + return instruction; + } + + @Test + public void testCreateParDoOperation() throws Exception { + List priorOperations = Arrays.asList(new Operation[]{ + new TestOperation(3), + new TestOperation(5), + new TestOperation(1) }); + + int producerIndex = 1; + int producerOutputNum = 2; + + ParallelInstruction instruction = + createParDoInstruction(producerIndex, producerOutputNum, "DoFn"); + + BatchModeExecutionContext context = new BatchModeExecutionContext(); + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, + counterSet.getAddCounterMutator()); + Operation operation = MapTaskExecutorFactory.createOperation( + PipelineOptionsFactory.create(), + instruction, + context, + priorOperations, + counterPrefix, + counterSet.getAddCounterMutator(), stateSampler); + assertThat(operation, new IsInstanceOf(ParDoOperation.class)); + ParDoOperation parDoOperation = (ParDoOperation) operation; + + assertEquals(parDoOperation.receivers.length, 1); + assertEquals(parDoOperation.receivers[0].getReceiverCount(), 0); + assertEquals(parDoOperation.initializationState, + Operation.InitializationState.UNSTARTED); + assertThat(parDoOperation.fn, + new IsInstanceOf(NormalParDoFn.class)); + NormalParDoFn normalParDoFn = (NormalParDoFn) parDoOperation.fn; + + assertThat(normalParDoFn.fn, + new IsInstanceOf(TestDoFn.class)); + + assertSame( + parDoOperation, + priorOperations.get(producerIndex).receivers[producerOutputNum] + .getOnlyReceiver()); + + assertEquals(context, normalParDoFn.executionContext); + } + + static ParallelInstruction createPartialGroupByKeyInstruction( + int producerIndex, + int producerOutputNum) { + InstructionInput cloudInput = new InstructionInput(); + cloudInput.setProducerInstructionIndex(producerIndex); + cloudInput.setOutputNum(producerOutputNum); + + PartialGroupByKeyInstruction pgbkInstruction = + new PartialGroupByKeyInstruction(); + pgbkInstruction.setInput(cloudInput); + pgbkInstruction.setInputElementCodec( + makeCloudEncoding(FullWindowedValueCoder.class.getName(), + makeCloudEncoding("KvCoder", + makeCloudEncoding("StringUtf8Coder"), + makeCloudEncoding("BigEndianIntegerCoder")), + IntervalWindow.getCoder().asCloudObject())); + + InstructionOutput output = new InstructionOutput(); + output.setName("pgbk_output_name"); + output.setCodec(makeCloudEncoding( + "KvCoder", + makeCloudEncoding("StringUtf8Coder"), + makeCloudEncoding( + "IterableCoder", + makeCloudEncoding("BigEndianIntegerCoder")))); + + ParallelInstruction instruction = new ParallelInstruction(); + instruction.setPartialGroupByKey(pgbkInstruction); + instruction.setOutputs(Arrays.asList(output)); + + return instruction; + } + + @Test + public void testCreatePartialGroupByKeyOperation() throws Exception { + List priorOperations = Arrays.asList(new Operation[]{ + new TestOperation(3), + new TestOperation(5), + new TestOperation(1) }); + + int producerIndex = 1; + int producerOutputNum = 2; + + ParallelInstruction instruction = + createPartialGroupByKeyInstruction(producerIndex, producerOutputNum); + + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, + counterSet.getAddCounterMutator()); + Operation operation = MapTaskExecutorFactory.createOperation( + PipelineOptionsFactory.create(), + instruction, + new BatchModeExecutionContext(), + priorOperations, + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + assertThat(operation, instanceOf(PartialGroupByKeyOperation.class)); + PartialGroupByKeyOperation pgbkOperation = + (PartialGroupByKeyOperation) operation; + + assertEquals(pgbkOperation.receivers.length, 1); + assertEquals(pgbkOperation.receivers[0].getReceiverCount(), 0); + assertEquals(pgbkOperation.initializationState, + Operation.InitializationState.UNSTARTED); + + assertSame( + pgbkOperation, + priorOperations.get(producerIndex).receivers[producerOutputNum] + .getOnlyReceiver()); + } + + static ParallelInstruction createFlattenInstruction( + int producerIndex1, + int producerOutputNum1, + int producerIndex2, + int producerOutputNum2, + String systemName) { + List cloudInputs = new ArrayList<>(); + + InstructionInput cloudInput1 = new InstructionInput(); + cloudInput1.setProducerInstructionIndex(producerIndex1); + cloudInput1.setOutputNum(producerOutputNum1); + cloudInputs.add(cloudInput1); + + InstructionInput cloudInput2 = new InstructionInput(); + cloudInput2.setProducerInstructionIndex(producerIndex2); + cloudInput2.setOutputNum(producerOutputNum2); + cloudInputs.add(cloudInput2); + + FlattenInstruction flattenInstruction = new FlattenInstruction(); + flattenInstruction.setInputs(cloudInputs); + + InstructionOutput output = new InstructionOutput(); + output.setName("flatten_output_name"); + output.setCodec(makeCloudEncoding(StringUtf8Coder.class.getName())); + + ParallelInstruction instruction = new ParallelInstruction(); + instruction.setFlatten(flattenInstruction); + instruction.setOutputs(Arrays.asList(output)); + instruction.setSystemName(systemName); + + return instruction; + } + + @Test + public void testCreateFlattenOperation() throws Exception { + List priorOperations = Arrays.asList(new Operation[]{ + new TestOperation(3), + new TestOperation(5), + new TestOperation(1) }); + + int producerIndex1 = 1; + int producerOutputNum1 = 2; + int producerIndex2 = 0; + int producerOutputNum2 = 1; + + ParallelInstruction instruction = + createFlattenInstruction(producerIndex1, producerOutputNum1, + producerIndex2, producerOutputNum2, "Flatten"); + + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, + counterSet.getAddCounterMutator()); + Operation operation = MapTaskExecutorFactory.createOperation( + PipelineOptionsFactory.create(), + instruction, + new BatchModeExecutionContext(), + priorOperations, + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + assertThat(operation, new IsInstanceOf(FlattenOperation.class)); + FlattenOperation flattenOperation = (FlattenOperation) operation; + + assertEquals(flattenOperation.receivers.length, 1); + assertEquals(flattenOperation.receivers[0].getReceiverCount(), 0); + assertEquals(flattenOperation.initializationState, + Operation.InitializationState.UNSTARTED); + + assertSame( + flattenOperation, + priorOperations.get(producerIndex1).receivers[producerOutputNum1] + .getOnlyReceiver()); + assertSame( + flattenOperation, + priorOperations.get(producerIndex2).receivers[producerOutputNum2] + .getOnlyReceiver()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/NormalParDoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/NormalParDoFnTest.java new file mode 100644 index 000000000000..f94ab8339f9d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/NormalParDoFnTest.java @@ -0,0 +1,331 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.hamcrest.core.AnyOf.anyOf; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.Receiver; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for NormalParDoFn. + */ +@RunWith(JUnit4.class) +public class NormalParDoFnTest { + static class TestDoFn extends DoFn { + enum State { UNSTARTED, STARTED, PROCESSING, FINISHED } + State state = State.UNSTARTED; + + List sideOutputTupleTags; + + public TestDoFn(List sideOutputTags) { + sideOutputTupleTags = new ArrayList<>(); + for (String sideOutputTag : sideOutputTags) { + sideOutputTupleTags.add(new TupleTag(sideOutputTag)); + } + } + + @Override + public void startBundle(Context c) { + assertEquals(State.UNSTARTED, state); + state = State.STARTED; + outputToAll(c, "started"); + } + + @Override + public void processElement(ProcessContext c) { + assertThat(state, anyOf(equalTo(State.STARTED), + equalTo(State.PROCESSING))); + state = State.PROCESSING; + outputToAll(c, "processing: " + c.element()); + } + + @Override + public void finishBundle(Context c) { + assertThat(state, anyOf(equalTo(State.STARTED), + equalTo(State.PROCESSING))); + state = State.FINISHED; + outputToAll(c, "finished"); + } + + private void outputToAll(Context c, String value) { + c.output(value); + for (TupleTag sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); + } + } + } + + static class TestErrorDoFn extends DoFn { + + // Used to test nested stack traces. + private void nestedFunctionBeta(String s) { + throw new RuntimeException(s); + } + + private void nestedFunctionAlpha(String s) { + nestedFunctionBeta(s); + } + + @Override + public void startBundle(Context c) { + nestedFunctionAlpha("test error in initialize"); + } + + @Override + public void processElement(ProcessContext c) { + nestedFunctionBeta("test error in process"); + } + + @Override + public void finishBundle(Context c) { + throw new RuntimeException("test error in finalize"); + } + } + + static class TestReceiver implements Receiver { + List receivedElems = new ArrayList<>(); + + @Override + public void process(Object outputElem) { + receivedElems.add(outputElem); + } + } + + @Test + public void testNormalParDoFn() throws Exception { + List sideOutputTags = Arrays.asList("tag1", "tag2", "tag3"); + + TestDoFn fn = new TestDoFn(sideOutputTags); + TestReceiver receiver = new TestReceiver(); + TestReceiver receiver1 = new TestReceiver(); + TestReceiver receiver2 = new TestReceiver(); + TestReceiver receiver3 = new TestReceiver(); + + PTuple sideInputValues = PTuple.empty(); + + List outputTags = new ArrayList<>(); + outputTags.add("output"); + outputTags.addAll(sideOutputTags); + NormalParDoFn normalParDoFn = + new NormalParDoFn(PipelineOptionsFactory.create(), + fn, sideInputValues, outputTags, "doFn", + new BatchModeExecutionContext(), + (new CounterSet()).getAddCounterMutator()); + + normalParDoFn.startBundle(receiver, receiver1, receiver2, receiver3); + + normalParDoFn.processElement(WindowedValue.valueInGlobalWindow(3)); + normalParDoFn.processElement(WindowedValue.valueInGlobalWindow(42)); + normalParDoFn.processElement(WindowedValue.valueInGlobalWindow(666)); + + normalParDoFn.finishBundle(); + + Object[] expectedReceivedElems = { + WindowedValue.valueInGlobalWindow("started"), + WindowedValue.valueInGlobalWindow("processing: 3"), + WindowedValue.valueInGlobalWindow("processing: 42"), + WindowedValue.valueInGlobalWindow("processing: 666"), + WindowedValue.valueInGlobalWindow("finished"), + }; + assertArrayEquals(expectedReceivedElems, receiver.receivedElems.toArray()); + + Object[] expectedReceivedElems1 = { + WindowedValue.valueInGlobalWindow("tag1: started"), + WindowedValue.valueInGlobalWindow("tag1: processing: 3"), + WindowedValue.valueInGlobalWindow("tag1: processing: 42"), + WindowedValue.valueInGlobalWindow("tag1: processing: 666"), + WindowedValue.valueInGlobalWindow("tag1: finished"), + }; + assertArrayEquals(expectedReceivedElems1, receiver1.receivedElems.toArray()); + + Object[] expectedReceivedElems2 = { + WindowedValue.valueInGlobalWindow("tag2: started"), + WindowedValue.valueInGlobalWindow("tag2: processing: 3"), + WindowedValue.valueInGlobalWindow("tag2: processing: 42"), + WindowedValue.valueInGlobalWindow("tag2: processing: 666"), + WindowedValue.valueInGlobalWindow("tag2: finished"), + }; + assertArrayEquals(expectedReceivedElems2, receiver2.receivedElems.toArray()); + + Object[] expectedReceivedElems3 = { + WindowedValue.valueInGlobalWindow("tag3: started"), + WindowedValue.valueInGlobalWindow("tag3: processing: 3"), + WindowedValue.valueInGlobalWindow("tag3: processing: 42"), + WindowedValue.valueInGlobalWindow("tag3: processing: 666"), + WindowedValue.valueInGlobalWindow("tag3: finished"), + }; + assertArrayEquals(expectedReceivedElems3, receiver3.receivedElems.toArray()); + } + + @Test + public void testUnexpectedNumberOfReceivers() throws Exception { + TestDoFn fn = new TestDoFn(Collections.emptyList()); + TestReceiver receiver = new TestReceiver(); + + PTuple sideInputValues = PTuple.empty(); + List outputTags = Arrays.asList("output"); + NormalParDoFn normalParDoFn = + new NormalParDoFn(PipelineOptionsFactory.create(), + fn, sideInputValues, outputTags, "doFn", + new BatchModeExecutionContext(), + (new CounterSet()).getAddCounterMutator()); + + try { + normalParDoFn.startBundle(); + fail("should have failed"); + } catch (Throwable exn) { + assertThat(exn.toString(), + containsString("unexpected number of receivers")); + } + try { + normalParDoFn.startBundle(receiver, receiver); + fail("should have failed"); + } catch (Throwable exn) { + assertThat(exn.toString(), + containsString("unexpected number of receivers")); + } + } + + private List stackTraceFrameStrings(Throwable t) { + List stack = new ArrayList<>(); + for (StackTraceElement frame : t.getStackTrace()) { + // Make sure that the frame has the expected name. + stack.add(frame.toString()); + } + return stack; + } + + @Test + public void testErrorPropagation() throws Exception { + TestErrorDoFn fn = new TestErrorDoFn(); + TestReceiver receiver = new TestReceiver(); + + PTuple sideInputValues = PTuple.empty(); + List outputTags = Arrays.asList("output"); + NormalParDoFn normalParDoFn = + new NormalParDoFn(PipelineOptionsFactory.create(), + fn, sideInputValues, outputTags, "doFn", + new BatchModeExecutionContext(), + (new CounterSet()).getAddCounterMutator()); + + try { + normalParDoFn.startBundle(receiver); + fail("should have failed"); + } catch (Exception exn) { + // Because we're calling this from inside the SDK and not from a + // user's program (e.g. through Pipeline.run), the error should + // be thrown as a UserCodeException. The cause of the + // UserCodeError shouldn't contain any of the stack from within + // the SDK, since we don't want to overwhelm users with stack + // frames outside of their control. + assertThat(exn, instanceOf(UserCodeException.class)); + // Stack trace of the cause should contain three frames: + // TestErrorDoFn.nestedFunctionBeta + // TestErrorDoFn.nestedFunctionAlpha + // TestErrorDoFn.startBundle + assertThat(stackTraceFrameStrings(exn.getCause()), contains( + containsString("TestErrorDoFn.nestedFunctionBeta"), + containsString("TestErrorDoFn.nestedFunctionAlpha"), + containsString("TestErrorDoFn.startBundle"))); + assertThat(exn.toString(), + containsString("test error in initialize")); + } + + try { + normalParDoFn.processElement(WindowedValue.valueInGlobalWindow(3)); + fail("should have failed"); + } catch (Exception exn) { + // Exception should be a UserCodeException since we're calling + // from inside the SDK. + assertThat(exn, instanceOf(UserCodeException.class)); + // Stack trace of the cause should contain two frames: + // TestErrorDoFn.nestedFunctionBeta + // TestErrorDoFn.processElement + assertThat(stackTraceFrameStrings(exn.getCause()), contains( + containsString("TestErrorDoFn.nestedFunctionBeta"), + containsString("TestErrorDoFn.processElement"))); + assertThat(exn.toString(), containsString("test error in process")); + } + + try { + normalParDoFn.finishBundle(); + fail("should have failed"); + } catch (Exception exn) { + // Exception should be a UserCodeException since we're calling + // from inside the SDK. + assertThat(exn, instanceOf(UserCodeException.class)); + // Stack trace should only contain a single frame: + // TestErrorDoFn.finishBundle + assertThat(stackTraceFrameStrings(exn.getCause()), contains( + containsString("TestErrorDoFn.finishBundle"))); + assertThat(exn.toString(), containsString("test error in finalize")); + } + } + + @Test + public void testUndeclaredSideOutputs() throws Exception { + TestDoFn fn = new TestDoFn(Arrays.asList("declared", "undecl1", "undecl2", "undecl3")); + CounterSet counters = new CounterSet(); + NormalParDoFn normalParDoFn = + new NormalParDoFn(PipelineOptionsFactory.create(), fn, PTuple.empty(), + Arrays.asList("output", "declared"), "doFn", + new BatchModeExecutionContext(), + counters.getAddCounterMutator()); + + normalParDoFn.startBundle(new TestReceiver(), new TestReceiver()); + normalParDoFn.processElement(WindowedValue.valueInGlobalWindow(5)); + normalParDoFn.finishBundle(); + + assertEquals( + new CounterSet( + Counter.longs("implicit-undecl1-ElementCount", SUM) + .resetToValue(3L), + Counter.longs("implicit-undecl2-ElementCount", SUM) + .resetToValue(3L), + Counter.longs("implicit-undecl3-ElementCount", SUM) + .resetToValue(3L)), + counters); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/OrderedCodeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/OrderedCodeTest.java new file mode 100644 index 000000000000..6f467ba1173a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/OrderedCodeTest.java @@ -0,0 +1,504 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.io.BaseEncoding; +import com.google.common.primitives.Bytes; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for OrderedCode. + */ +@RunWith(JUnit4.class) +public class OrderedCodeTest { + @Test + public void testWriteInfinity() { + OrderedCode orderedCode = new OrderedCode(); + try { + orderedCode.readInfinity(); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // expected + } + orderedCode.writeInfinity(); + assertTrue(orderedCode.readInfinity()); + try { + orderedCode.readInfinity(); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // expected + } + } + + @Test + public void testWriteBytes() { + byte[] first = { 'a', 'b', 'c'}; + byte[] second = { 'd', 'e', 'f'}; + byte[] last = { 'x', 'y', 'z'}; + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeBytes(first); + byte[] firstEncoded = orderedCode.getEncodedBytes(); + assertArrayEquals(orderedCode.readBytes(), first); + + orderedCode.writeBytes(first); + orderedCode.writeBytes(second); + orderedCode.writeBytes(last); + byte[] allEncoded = orderedCode.getEncodedBytes(); + assertArrayEquals(orderedCode.readBytes(), first); + assertArrayEquals(orderedCode.readBytes(), second); + assertArrayEquals(orderedCode.readBytes(), last); + + orderedCode = new OrderedCode(firstEncoded); + orderedCode.writeBytes(second); + orderedCode.writeBytes(last); + assertArrayEquals(orderedCode.getEncodedBytes(), allEncoded); + assertArrayEquals(orderedCode.readBytes(), first); + assertArrayEquals(orderedCode.readBytes(), second); + assertArrayEquals(orderedCode.readBytes(), last); + + orderedCode = new OrderedCode(allEncoded); + assertArrayEquals(orderedCode.readBytes(), first); + assertArrayEquals(orderedCode.readBytes(), second); + assertArrayEquals(orderedCode.readBytes(), last); + } + + @Test + public void testWriteNumIncreasing() { + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeNumIncreasing(0); + orderedCode.writeNumIncreasing(1); + orderedCode.writeNumIncreasing(Long.MIN_VALUE); + orderedCode.writeNumIncreasing(Long.MAX_VALUE); + assertEquals(orderedCode.readNumIncreasing(), 0); + assertEquals(orderedCode.readNumIncreasing(), 1); + assertEquals(orderedCode.readNumIncreasing(), Long.MIN_VALUE); + assertEquals(orderedCode.readNumIncreasing(), Long.MAX_VALUE); + } + + /** + * Assert that encoding the specified long via + * {@link OrderedCode#writeSignedNumIncreasing(long)} results in the bytes + * represented by the specified string of hex digits. + * E.g. assertSignedNumIncreasingEncodingEquals("3fbf", -65) asserts that + * -65 is encoded as { (byte) 0x3f, (byte) 0xbf }. + */ + private static void assertSignedNumIncreasingEncodingEquals( + String expectedHexEncoding, long num) { + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeSignedNumIncreasing(num); + assertEquals( + "Unexpected encoding for " + num, + expectedHexEncoding, + BaseEncoding.base16().lowerCase().encode(orderedCode.getEncodedBytes())); + } + + /** + * Assert that encoding various long values via + * {@link OrderedCode#writeSignedNumIncreasing(long)} produces the expected + * bytes. Expected byte sequences were generated via the c++ (authoritative) + * implementation of OrderedCode::WriteSignedNumIncreasing. + */ + @Test + public void testSignedNumIncreasing_write() { + assertSignedNumIncreasingEncodingEquals( + "003f8000000000000000", Long.MIN_VALUE); + assertSignedNumIncreasingEncodingEquals( + "003f8000000000000001", Long.MIN_VALUE + 1); + assertSignedNumIncreasingEncodingEquals( + "077fffffff", Integer.MIN_VALUE - 1L); + assertSignedNumIncreasingEncodingEquals("0780000000", Integer.MIN_VALUE); + assertSignedNumIncreasingEncodingEquals( + "0780000001", Integer.MIN_VALUE + 1); + assertSignedNumIncreasingEncodingEquals("3fbf", -65); + assertSignedNumIncreasingEncodingEquals("40", -64); + assertSignedNumIncreasingEncodingEquals("41", -63); + assertSignedNumIncreasingEncodingEquals("7d", -3); + assertSignedNumIncreasingEncodingEquals("7e", -2); + assertSignedNumIncreasingEncodingEquals("7f", -1); + assertSignedNumIncreasingEncodingEquals("80", 0); + assertSignedNumIncreasingEncodingEquals("81", 1); + assertSignedNumIncreasingEncodingEquals("82", 2); + assertSignedNumIncreasingEncodingEquals("83", 3); + assertSignedNumIncreasingEncodingEquals("bf", 63); + assertSignedNumIncreasingEncodingEquals("c040", 64); + assertSignedNumIncreasingEncodingEquals("c041", 65); + assertSignedNumIncreasingEncodingEquals( + "f87ffffffe", Integer.MAX_VALUE - 1); + assertSignedNumIncreasingEncodingEquals("f87fffffff", Integer.MAX_VALUE); + assertSignedNumIncreasingEncodingEquals( + "f880000000", Integer.MAX_VALUE + 1L); + assertSignedNumIncreasingEncodingEquals( + "ffc07ffffffffffffffe", Long.MAX_VALUE - 1); + assertSignedNumIncreasingEncodingEquals( + "ffc07fffffffffffffff", Long.MAX_VALUE); + } + + /** + * Convert a string of hex digits (e.g. "3fbf") to a byte[] + * (e.g. { (byte) 0x3f, (byte) 0xbf }). + */ + private static byte[] bytesFromHexString(String hexDigits) { + return BaseEncoding.base16().lowerCase().decode(hexDigits); + } + + /** + * Assert that decoding (via {@link OrderedCode#readSignedNumIncreasing()}) + * the bytes represented by the specified string of hex digits results in the + * expected long value. + * E.g. assertDecodedSignedNumIncreasingEquals(-65, "3fbf") asserts that the + * byte array { (byte) 0x3f, (byte) 0xbf } is decoded as -65. + */ + private static void assertDecodedSignedNumIncreasingEquals( + long expectedNum, String encodedHexString) { + OrderedCode orderedCode = + new OrderedCode(bytesFromHexString(encodedHexString)); + assertEquals( + "Unexpected value when decoding 0x" + encodedHexString, + expectedNum, + orderedCode.readSignedNumIncreasing()); + assertFalse( + "Unexpected encoded bytes remain after decoding 0x" + encodedHexString, + orderedCode.hasRemainingEncodedBytes()); + } + + /** + * Assert that decoding various sequences of bytes via + * {@link OrderedCode#readSignedNumIncreasing()} produces the expected long + * value. + * Input byte sequences were generated via the c++ (authoritative) + * implementation of OrderedCode::WriteSignedNumIncreasing. + */ + @Test + public void testSignedNumIncreasing_read() { + assertDecodedSignedNumIncreasingEquals( + Long.MIN_VALUE, "003f8000000000000000"); + assertDecodedSignedNumIncreasingEquals( + Long.MIN_VALUE + 1, "003f8000000000000001"); + assertDecodedSignedNumIncreasingEquals( + Integer.MIN_VALUE - 1L, "077fffffff"); + assertDecodedSignedNumIncreasingEquals(Integer.MIN_VALUE, "0780000000"); + assertDecodedSignedNumIncreasingEquals(Integer.MIN_VALUE + 1, "0780000001"); + assertDecodedSignedNumIncreasingEquals(-65, "3fbf"); + assertDecodedSignedNumIncreasingEquals(-64, "40"); + assertDecodedSignedNumIncreasingEquals(-63, "41"); + assertDecodedSignedNumIncreasingEquals(-3, "7d"); + assertDecodedSignedNumIncreasingEquals(-2, "7e"); + assertDecodedSignedNumIncreasingEquals(-1, "7f"); + assertDecodedSignedNumIncreasingEquals(0, "80"); + assertDecodedSignedNumIncreasingEquals(1, "81"); + assertDecodedSignedNumIncreasingEquals(2, "82"); + assertDecodedSignedNumIncreasingEquals(3, "83"); + assertDecodedSignedNumIncreasingEquals(63, "bf"); + assertDecodedSignedNumIncreasingEquals(64, "c040"); + assertDecodedSignedNumIncreasingEquals(65, "c041"); + assertDecodedSignedNumIncreasingEquals(Integer.MAX_VALUE - 1, "f87ffffffe"); + assertDecodedSignedNumIncreasingEquals(Integer.MAX_VALUE, "f87fffffff"); + assertDecodedSignedNumIncreasingEquals( + Integer.MAX_VALUE + 1L, "f880000000"); + assertDecodedSignedNumIncreasingEquals( + Long.MAX_VALUE - 1, "ffc07ffffffffffffffe"); + assertDecodedSignedNumIncreasingEquals( + Long.MAX_VALUE, "ffc07fffffffffffffff"); + } + + /** + * Assert that encoding (via + * {@link OrderedCode#writeSignedNumIncreasing(long)}) the specified long + * value and then decoding (via {@link OrderedCode#readSignedNumIncreasing()}) + * results in the original value. + */ + private static void assertSignedNumIncreasingWriteAndReadIsLossless( + long num) { + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeSignedNumIncreasing(num); + assertEquals( + "Unexpected result when decoding writeSignedNumIncreasing(" + num + ")", + num, + orderedCode.readSignedNumIncreasing()); + assertFalse("Unexpected remaining encoded bytes after decoding " + num, + orderedCode.hasRemainingEncodedBytes()); + } + + /** + * Assert that for various long values, encoding (via + * {@link OrderedCode#writeSignedNumIncreasing(long)}) and then decoding (via + * {@link OrderedCode#readSignedNumIncreasing()}) results in the original + * value. + */ + @Test + public void testSignedNumIncreasing_writeAndRead() { + assertSignedNumIncreasingWriteAndReadIsLossless(Long.MIN_VALUE); + assertSignedNumIncreasingWriteAndReadIsLossless(Long.MIN_VALUE + 1); + assertSignedNumIncreasingWriteAndReadIsLossless(Integer.MIN_VALUE - 1L); + assertSignedNumIncreasingWriteAndReadIsLossless(Integer.MIN_VALUE); + assertSignedNumIncreasingWriteAndReadIsLossless(Integer.MIN_VALUE + 1); + assertSignedNumIncreasingWriteAndReadIsLossless(-65); + assertSignedNumIncreasingWriteAndReadIsLossless(-64); + assertSignedNumIncreasingWriteAndReadIsLossless(-63); + assertSignedNumIncreasingWriteAndReadIsLossless(-3); + assertSignedNumIncreasingWriteAndReadIsLossless(-2); + assertSignedNumIncreasingWriteAndReadIsLossless(-1); + assertSignedNumIncreasingWriteAndReadIsLossless(0); + assertSignedNumIncreasingWriteAndReadIsLossless(1); + assertSignedNumIncreasingWriteAndReadIsLossless(2); + assertSignedNumIncreasingWriteAndReadIsLossless(3); + assertSignedNumIncreasingWriteAndReadIsLossless(63); + assertSignedNumIncreasingWriteAndReadIsLossless(64); + assertSignedNumIncreasingWriteAndReadIsLossless(65); + assertSignedNumIncreasingWriteAndReadIsLossless(Integer.MAX_VALUE - 1); + assertSignedNumIncreasingWriteAndReadIsLossless(Integer.MAX_VALUE); + assertSignedNumIncreasingWriteAndReadIsLossless(Integer.MAX_VALUE + 1L); + assertSignedNumIncreasingWriteAndReadIsLossless(Long.MAX_VALUE - 1); + assertSignedNumIncreasingWriteAndReadIsLossless(Long.MAX_VALUE); + } + + @Test + public void testLog2Floor_Positive() { + OrderedCode orderedCode = new OrderedCode(); + assertEquals(0, orderedCode.log2Floor(1)); + assertEquals(1, orderedCode.log2Floor(2)); + assertEquals(1, orderedCode.log2Floor(3)); + assertEquals(2, orderedCode.log2Floor(4)); + assertEquals(5, orderedCode.log2Floor(63)); + assertEquals(6, orderedCode.log2Floor(64)); + assertEquals(62, orderedCode.log2Floor(Long.MAX_VALUE)); + } + + /** + * OrderedCode.log2Floor(long) is defined to return -1 given an input of zero. + */ + @Test + public void testLog2Floor_zero() { + OrderedCode orderedCode = new OrderedCode(); + assertEquals(-1, orderedCode.log2Floor(0)); + } + + @Test + public void testLog2Floor_negative() { + OrderedCode orderedCode = new OrderedCode(); + try { + orderedCode.log2Floor(-1); + fail("Expected an IllegalArgumentException."); + } catch (IllegalArgumentException expected) { + // Expected! + } + } + + @Test + public void testGetSignedEncodingLength() { + OrderedCode orderedCode = new OrderedCode(); + assertEquals(10, orderedCode.getSignedEncodingLength(Long.MIN_VALUE)); + assertEquals(10, orderedCode.getSignedEncodingLength(~(1L << 62))); + assertEquals(9, orderedCode.getSignedEncodingLength(~(1L << 62) + 1)); + assertEquals(3, orderedCode.getSignedEncodingLength(-8193)); + assertEquals(2, orderedCode.getSignedEncodingLength(-8192)); + assertEquals(2, orderedCode.getSignedEncodingLength(-65)); + assertEquals(1, orderedCode.getSignedEncodingLength(-64)); + assertEquals(1, orderedCode.getSignedEncodingLength(-2)); + assertEquals(1, orderedCode.getSignedEncodingLength(-1)); + assertEquals(1, orderedCode.getSignedEncodingLength(0)); + assertEquals(1, orderedCode.getSignedEncodingLength(1)); + assertEquals(1, orderedCode.getSignedEncodingLength(63)); + assertEquals(2, orderedCode.getSignedEncodingLength(64)); + assertEquals(2, orderedCode.getSignedEncodingLength(8191)); + assertEquals(3, orderedCode.getSignedEncodingLength(8192)); + assertEquals(9, orderedCode.getSignedEncodingLength((1L << 62)) - 1); + assertEquals(10, orderedCode.getSignedEncodingLength(1L << 62)); + assertEquals(10, orderedCode.getSignedEncodingLength(Long.MAX_VALUE)); + } + + @Test + public void testWriteTrailingBytes() { + byte[] escapeChars = new byte[] { OrderedCode.ESCAPE1, + OrderedCode.NULL_CHARACTER, OrderedCode.SEPARATOR, OrderedCode.ESCAPE2, + OrderedCode.INFINITY, OrderedCode.FF_CHARACTER}; + byte[] anotherArray = new byte[] { 'a', 'b', 'c', 'd', 'e' }; + + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeTrailingBytes(escapeChars); + assertArrayEquals(orderedCode.getEncodedBytes(), escapeChars); + assertArrayEquals(orderedCode.readTrailingBytes(), escapeChars); + try { + orderedCode.readInfinity(); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // expected + } + + orderedCode = new OrderedCode(); + orderedCode.writeTrailingBytes(anotherArray); + assertArrayEquals(orderedCode.getEncodedBytes(), anotherArray); + assertArrayEquals(orderedCode.readTrailingBytes(), anotherArray); + } + + @Test + public void testMixedWrite() { + byte[] first = { 'a', 'b', 'c'}; + byte[] second = { 'd', 'e', 'f'}; + byte[] last = { 'x', 'y', 'z'}; + byte[] escapeChars = new byte[] { OrderedCode.ESCAPE1, + OrderedCode.NULL_CHARACTER, OrderedCode.SEPARATOR, OrderedCode.ESCAPE2, + OrderedCode.INFINITY, OrderedCode.FF_CHARACTER}; + + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeBytes(first); + orderedCode.writeBytes(second); + orderedCode.writeBytes(last); + orderedCode.writeInfinity(); + orderedCode.writeNumIncreasing(0); + orderedCode.writeNumIncreasing(1); + orderedCode.writeNumIncreasing(Long.MIN_VALUE); + orderedCode.writeNumIncreasing(Long.MAX_VALUE); + orderedCode.writeSignedNumIncreasing(0); + orderedCode.writeSignedNumIncreasing(1); + orderedCode.writeSignedNumIncreasing(Long.MIN_VALUE); + orderedCode.writeSignedNumIncreasing(Long.MAX_VALUE); + orderedCode.writeTrailingBytes(escapeChars); + byte[] allEncoded = orderedCode.getEncodedBytes(); + assertArrayEquals(orderedCode.readBytes(), first); + assertArrayEquals(orderedCode.readBytes(), second); + assertFalse(orderedCode.readInfinity()); + assertArrayEquals(orderedCode.readBytes(), last); + assertTrue(orderedCode.readInfinity()); + assertEquals(orderedCode.readNumIncreasing(), 0); + assertEquals(orderedCode.readNumIncreasing(), 1); + assertFalse(orderedCode.readInfinity()); + assertEquals(orderedCode.readNumIncreasing(), Long.MIN_VALUE); + assertEquals(orderedCode.readNumIncreasing(), Long.MAX_VALUE); + assertEquals(orderedCode.readSignedNumIncreasing(), 0); + assertEquals(orderedCode.readSignedNumIncreasing(), 1); + assertFalse(orderedCode.readInfinity()); + assertEquals(orderedCode.readSignedNumIncreasing(), Long.MIN_VALUE); + assertEquals(orderedCode.readSignedNumIncreasing(), Long.MAX_VALUE); + assertArrayEquals(orderedCode.getEncodedBytes(), escapeChars); + assertArrayEquals(orderedCode.readTrailingBytes(), escapeChars); + + orderedCode = new OrderedCode(allEncoded); + assertArrayEquals(orderedCode.readBytes(), first); + assertArrayEquals(orderedCode.readBytes(), second); + assertFalse(orderedCode.readInfinity()); + assertArrayEquals(orderedCode.readBytes(), last); + assertTrue(orderedCode.readInfinity()); + assertEquals(orderedCode.readNumIncreasing(), 0); + assertEquals(orderedCode.readNumIncreasing(), 1); + assertFalse(orderedCode.readInfinity()); + assertEquals(orderedCode.readNumIncreasing(), Long.MIN_VALUE); + assertEquals(orderedCode.readNumIncreasing(), Long.MAX_VALUE); + assertEquals(orderedCode.readSignedNumIncreasing(), 0); + assertEquals(orderedCode.readSignedNumIncreasing(), 1); + assertFalse(orderedCode.readInfinity()); + assertEquals(orderedCode.readSignedNumIncreasing(), Long.MIN_VALUE); + assertEquals(orderedCode.readSignedNumIncreasing(), Long.MAX_VALUE); + assertArrayEquals(orderedCode.getEncodedBytes(), escapeChars); + assertArrayEquals(orderedCode.readTrailingBytes(), escapeChars); + } + + @Test + public void testEdgeCases() { + byte[] ffChar = {OrderedCode.FF_CHARACTER}; + byte[] nullChar = {OrderedCode.NULL_CHARACTER}; + + byte[] separatorEncoded = {OrderedCode.ESCAPE1, OrderedCode.SEPARATOR}; + byte[] ffCharEncoded = {OrderedCode.ESCAPE1, OrderedCode.NULL_CHARACTER}; + byte[] nullCharEncoded = {OrderedCode.ESCAPE2, OrderedCode.FF_CHARACTER}; + byte[] infinityEncoded = {OrderedCode.ESCAPE2, OrderedCode.INFINITY}; + + OrderedCode orderedCode = new OrderedCode(); + orderedCode.writeBytes(ffChar); + orderedCode.writeBytes(nullChar); + orderedCode.writeInfinity(); + assertArrayEquals(orderedCode.getEncodedBytes(), + Bytes.concat(ffCharEncoded, separatorEncoded, + nullCharEncoded, separatorEncoded, + infinityEncoded)); + assertArrayEquals(orderedCode.readBytes(), ffChar); + assertArrayEquals(orderedCode.readBytes(), nullChar); + assertTrue(orderedCode.readInfinity()); + + orderedCode = new OrderedCode( + Bytes.concat(ffCharEncoded, separatorEncoded)); + assertArrayEquals(orderedCode.readBytes(), ffChar); + + orderedCode = new OrderedCode( + Bytes.concat(nullCharEncoded, separatorEncoded)); + assertArrayEquals(orderedCode.readBytes(), nullChar); + + byte[] invalidEncodingForRead = {OrderedCode.ESCAPE2, OrderedCode.ESCAPE2, + OrderedCode.ESCAPE1, OrderedCode.SEPARATOR}; + orderedCode = new OrderedCode(invalidEncodingForRead); + try { + orderedCode.readBytes(); + fail("Should have failed."); + } catch (Exception e) { + // Expected + } + assertTrue(orderedCode.hasRemainingEncodedBytes()); + } + + @Test + public void testHasRemainingEncodedBytes() { + byte[] bytes = { 'a', 'b', 'c'}; + long number = 12345; + + // Empty + OrderedCode orderedCode = new OrderedCode(); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + + // First and only field of each type. + orderedCode.writeBytes(bytes); + assertTrue(orderedCode.hasRemainingEncodedBytes()); + assertArrayEquals(orderedCode.readBytes(), bytes); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + + orderedCode.writeNumIncreasing(number); + assertTrue(orderedCode.hasRemainingEncodedBytes()); + assertEquals(orderedCode.readNumIncreasing(), number); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + + orderedCode.writeSignedNumIncreasing(number); + assertTrue(orderedCode.hasRemainingEncodedBytes()); + assertEquals(orderedCode.readSignedNumIncreasing(), number); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + + orderedCode.writeInfinity(); + assertTrue(orderedCode.hasRemainingEncodedBytes()); + assertTrue(orderedCode.readInfinity()); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + + orderedCode.writeTrailingBytes(bytes); + assertTrue(orderedCode.hasRemainingEncodedBytes()); + assertArrayEquals(orderedCode.readTrailingBytes(), bytes); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + + // Two fields of same type. + orderedCode.writeBytes(bytes); + orderedCode.writeBytes(bytes); + assertTrue(orderedCode.hasRemainingEncodedBytes()); + assertArrayEquals(orderedCode.readBytes(), bytes); + assertArrayEquals(orderedCode.readBytes(), bytes); + assertFalse(orderedCode.hasRemainingEncodedBytes()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ParDoFnFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ParDoFnFactoryTest.java new file mode 100644 index 000000000000..05a3864d9bd4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ParDoFnFactoryTest.java @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.api.services.dataflow.model.MultiOutputInfo; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.ParDoFn; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for ParDoFnFactory. + */ +@RunWith(JUnit4.class) +public class ParDoFnFactoryTest { + static class TestDoFn extends DoFn { + final String stringState; + final long longState; + + TestDoFn(String stringState, long longState) { + this.stringState = stringState; + this.longState = longState; + } + + @Override + public void processElement(ProcessContext c) { + throw new RuntimeException("not expecting to call this"); + } + } + + @Test + public void testCreateNormalParDoFn() throws Exception { + String stringState = "some state"; + long longState = 42L; + + TestDoFn fn = new TestDoFn(stringState, longState); + + String serializedFn = + StringUtils.byteArrayToJsonString( + SerializableUtils.serializeToByteArray(fn)); + + CloudObject cloudUserFn = CloudObject.forClassName("DoFn"); + addString(cloudUserFn, "serialized_fn", serializedFn); + + String tag = "output"; + MultiOutputInfo multiOutputInfo = new MultiOutputInfo(); + multiOutputInfo.setTag(tag); + List multiOutputInfos = + Arrays.asList(multiOutputInfo); + + BatchModeExecutionContext context = new BatchModeExecutionContext(); + CounterSet counters = new CounterSet(); + StateSampler stateSampler = new StateSampler( + "test", counters.getAddCounterMutator()); + ParDoFn parDoFn = ParDoFnFactory.create( + PipelineOptionsFactory.create(), + cloudUserFn, "name", null, multiOutputInfos, 1, + context, counters.getAddCounterMutator(), stateSampler); + + Assert.assertThat(parDoFn, new IsInstanceOf(NormalParDoFn.class)); + NormalParDoFn normalParDoFn = (NormalParDoFn) parDoFn; + + DoFn actualDoFn = normalParDoFn.fn; + Assert.assertThat(actualDoFn, new IsInstanceOf(TestDoFn.class)); + TestDoFn actualTestDoFn = (TestDoFn) actualDoFn; + + Assert.assertEquals(stringState, actualTestDoFn.stringState); + Assert.assertEquals(longState, actualTestDoFn.longState); + + Assert.assertEquals(context, normalParDoFn.executionContext); + } + + @Test + public void testCreateUnknownParDoFn() throws Exception { + CloudObject cloudUserFn = CloudObject.forClassName("UnknownKindOfDoFn"); + try { + CounterSet counters = new CounterSet(); + StateSampler stateSampler = new StateSampler( + "test", counters.getAddCounterMutator()); + ParDoFnFactory.create(PipelineOptionsFactory.create(), + cloudUserFn, "name", null, null, 1, + new BatchModeExecutionContext(), + counters.getAddCounterMutator(), + stateSampler); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + CoreMatchers.containsString( + "unable to create a ParDoFn")); + } + } + + // TODO: Test side inputs. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSourceTest.java new file mode 100644 index 000000000000..be8c972c5944 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/PartitioningShuffleSourceTest.java @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Tests for PartitioningShuffleSource. + */ +@RunWith(JUnit4.class) +public class PartitioningShuffleSourceTest { + static final List>> NO_KVS = Collections.emptyList(); + + static final Instant timestamp = new Instant(123000); + static final IntervalWindow window = new IntervalWindow(timestamp, timestamp.plus(1000)); + + static final List>> KVS = Arrays.asList( + WindowedValue.of(KV.of(1, "in 1a"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(1, "in 1b"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(2, "in 2a"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(2, "in 2b"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(3, "in 3"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(4, "in 4a"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(4, "in 4b"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(4, "in 4c"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(4, "in 4d"), timestamp, Lists.newArrayList(window)), + WindowedValue.of(KV.of(5, "in 5"), timestamp, Lists.newArrayList(window))); + + void runTestReadShuffleSource(List>> expected) + throws Exception { + Coder>> elemCoder = WindowedValue.getFullCoder( + KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()), + IntervalWindow.getCoder()); + + // Write to shuffle with PARTITION_KEYS ShuffleSink. + ShuffleSink> shuffleSink = new ShuffleSink<>( + PipelineOptionsFactory.create(), + null, ShuffleSink.ShuffleKind.PARTITION_KEYS, + elemCoder); + + TestShuffleWriter shuffleWriter = new TestShuffleWriter(); + + List actualSizes = new ArrayList<>(); + try (Sink.SinkWriter>> shuffleSinkWriter = + shuffleSink.writer(shuffleWriter)) { + for (WindowedValue> value : expected) { + actualSizes.add(shuffleSinkWriter.add(value)); + } + } + List records = shuffleWriter.getRecords(); + Assert.assertEquals(expected.size(), records.size()); + Assert.assertEquals(shuffleWriter.getSizes(), actualSizes); + + // Read from shuffle with PartitioningShuffleSource. + PartitioningShuffleSource shuffleSource = + new PartitioningShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, null, + elemCoder); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(shuffleSource); + + TestShuffleReader shuffleReader = new TestShuffleReader(); + List expectedSizes = new ArrayList<>(); + for (ShuffleEntry record : records) { + expectedSizes.add(record.length()); + shuffleReader.addEntry(record); + } + + List>> actual = new ArrayList<>(); + try (Source.SourceIterator>> iter = + shuffleSource.iterator(shuffleReader)) { + while (iter.hasNext()) { + Assert.assertTrue(iter.hasNext()); + actual.add(iter.next()); + } + Assert.assertFalse(iter.hasNext()); + try { + iter.next(); + Assert.fail("should have failed"); + } catch (NoSuchElementException exn) { + // As expected. + } + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(expectedSizes, observer.getActualSizes()); + } + + @Test + public void testReadEmptyShuffleSource() throws Exception { + runTestReadShuffleSource(NO_KVS); + } + + @Test + public void testReadNonEmptyShuffleSource() throws Exception { + runTestReadShuffleSource(KVS); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkFactoryTest.java new file mode 100644 index 000000000000..4b8901af34b5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkFactoryTest.java @@ -0,0 +1,187 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.encodeBase64String; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for ShuffleSinkFactory. + */ +@RunWith(JUnit4.class) +public class ShuffleSinkFactoryTest { + ShuffleSink runTestCreateShuffleSinkHelper(byte[] shuffleWriterConfig, + String shuffleKind, + CloudObject encoding, + FullWindowedValueCoder coder) + throws Exception { + CloudObject spec = CloudObject.forClassName("ShuffleSink"); + addString(spec, "shuffle_writer_config", encodeBase64String(shuffleWriterConfig)); + addString(spec, "shuffle_kind", shuffleKind); + + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(encoding); + + Sink sink = SinkFactory.create(PipelineOptionsFactory.create(), + cloudSink, + new BatchModeExecutionContext()); + Assert.assertThat(sink, new IsInstanceOf(ShuffleSink.class)); + ShuffleSink shuffleSink = (ShuffleSink) sink; + Assert.assertArrayEquals(shuffleWriterConfig, + shuffleSink.shuffleWriterConfig); + Assert.assertEquals(coder, shuffleSink.windowedElemCoder); + return shuffleSink; + } + + void runTestCreateUngroupingShuffleSink(byte[] shuffleWriterConfig, + CloudObject encoding, + FullWindowedValueCoder coder) + throws Exception { + ShuffleSink shuffleSink = runTestCreateShuffleSinkHelper( + shuffleWriterConfig, "ungrouped", encoding, coder); + Assert.assertEquals(ShuffleSink.ShuffleKind.UNGROUPED, + shuffleSink.shuffleKind); + Assert.assertFalse(shuffleSink.shardByKey); + Assert.assertFalse(shuffleSink.groupValues); + Assert.assertFalse(shuffleSink.sortValues); + Assert.assertNull(shuffleSink.keyCoder); + Assert.assertNull(shuffleSink.valueCoder); + Assert.assertNull(shuffleSink.sortKeyCoder); + Assert.assertNull(shuffleSink.sortValueCoder); + } + + void runTestCreatePartitioningShuffleSink(byte[] shuffleWriterConfig, + Coder keyCoder, + Coder valueCoder) + throws Exception { + FullWindowedValueCoder coder = (FullWindowedValueCoder) WindowedValue.getFullCoder( + KvCoder.of(keyCoder, valueCoder), IntervalWindow.getCoder()); + ShuffleSink shuffleSink = runTestCreateShuffleSinkHelper( + shuffleWriterConfig, "partition_keys", coder.asCloudObject(), coder); + Assert.assertEquals(ShuffleSink.ShuffleKind.PARTITION_KEYS, + shuffleSink.shuffleKind); + Assert.assertTrue(shuffleSink.shardByKey); + Assert.assertFalse(shuffleSink.groupValues); + Assert.assertFalse(shuffleSink.sortValues); + Assert.assertEquals(keyCoder, shuffleSink.keyCoder); + Assert.assertEquals(valueCoder, shuffleSink.valueCoder); + Assert.assertEquals(FullWindowedValueCoder.of(valueCoder, + IntervalWindow.getCoder()), + shuffleSink.windowedValueCoder); + Assert.assertNull(shuffleSink.sortKeyCoder); + Assert.assertNull(shuffleSink.sortValueCoder); + } + + void runTestCreateGroupingShuffleSink(byte[] shuffleWriterConfig, + Coder keyCoder, + Coder valueCoder) + throws Exception { + FullWindowedValueCoder coder = (FullWindowedValueCoder) WindowedValue.getFullCoder( + KvCoder.of(keyCoder, valueCoder), IntervalWindow.getCoder()); + ShuffleSink shuffleSink = runTestCreateShuffleSinkHelper( + shuffleWriterConfig, "group_keys", coder.asCloudObject(), coder); + Assert.assertEquals(ShuffleSink.ShuffleKind.GROUP_KEYS, + shuffleSink.shuffleKind); + Assert.assertTrue(shuffleSink.shardByKey); + Assert.assertTrue(shuffleSink.groupValues); + Assert.assertFalse(shuffleSink.sortValues); + Assert.assertEquals(keyCoder, shuffleSink.keyCoder); + Assert.assertEquals(valueCoder, shuffleSink.valueCoder); + Assert.assertNull(shuffleSink.windowedValueCoder); + Assert.assertNull(shuffleSink.sortKeyCoder); + Assert.assertNull(shuffleSink.sortValueCoder); + } + + void runTestCreateGroupingSortingShuffleSink(byte[] shuffleWriterConfig, + Coder keyCoder, + Coder sortKeyCoder, + Coder sortValueCoder) + throws Exception { + FullWindowedValueCoder coder = (FullWindowedValueCoder) WindowedValue.getFullCoder( + KvCoder.of(keyCoder, KvCoder.of(sortKeyCoder, sortValueCoder)), + IntervalWindow.getCoder()); + ShuffleSink shuffleSink = runTestCreateShuffleSinkHelper( + shuffleWriterConfig, "group_keys_and_sort_values", coder.asCloudObject(), coder); + Assert.assertEquals(ShuffleSink.ShuffleKind.GROUP_KEYS_AND_SORT_VALUES, + shuffleSink.shuffleKind); + Assert.assertTrue(shuffleSink.shardByKey); + Assert.assertTrue(shuffleSink.groupValues); + Assert.assertTrue(shuffleSink.sortValues); + Assert.assertEquals(keyCoder, shuffleSink.keyCoder); + Assert.assertEquals(KvCoder.of(sortKeyCoder, sortValueCoder), + shuffleSink.valueCoder); + Assert.assertEquals(sortKeyCoder, shuffleSink.sortKeyCoder); + Assert.assertEquals(sortValueCoder, shuffleSink.sortValueCoder); + Assert.assertNull(shuffleSink.windowedValueCoder); + } + + @Test + public void testCreateUngroupingShuffleSink() throws Exception { + FullWindowedValueCoder coder = (FullWindowedValueCoder) WindowedValue.getFullCoder( + StringUtf8Coder.of(), IntervalWindow.getCoder()); + runTestCreateUngroupingShuffleSink( + new byte[]{(byte) 0xE1}, + coder.asCloudObject(), + coder); + } + + @Test + public void testCreatePartitionShuffleSink() throws Exception { + runTestCreatePartitioningShuffleSink( + new byte[]{(byte) 0xE2}, + BigEndianIntegerCoder.of(), + StringUtf8Coder.of()); + } + + @Test + public void testCreateGroupingShuffleSink() throws Exception { + runTestCreateGroupingShuffleSink( + new byte[]{(byte) 0xE2}, + BigEndianIntegerCoder.of(), + WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder())); + } + + @Test + public void testCreateGroupingSortingShuffleSink() throws Exception { + runTestCreateGroupingSortingShuffleSink( + new byte[]{(byte) 0xE3}, + BigEndianIntegerCoder.of(), + StringUtf8Coder.of(), + VoidCoder.of()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkTest.java new file mode 100644 index 000000000000..3e390b8966af --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSinkTest.java @@ -0,0 +1,236 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink.SinkWriter; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for ShuffleSink. + */ +@RunWith(JUnit4.class) +public class ShuffleSinkTest { + static final List> NO_KVS = Collections.emptyList(); + + static final List> KVS = Arrays.asList( + KV.of(1, "in 1a"), + KV.of(1, "in 1b"), + KV.of(2, "in 2a"), + KV.of(2, "in 2b"), + KV.of(3, "in 3"), + KV.of(4, "in 4a"), + KV.of(4, "in 4b"), + KV.of(4, "in 4c"), + KV.of(4, "in 4d"), + KV.of(5, "in 5")); + + static final List>> NO_SORTING_KVS = + Collections.emptyList(); + + static final List>> SORTING_KVS = + Arrays.asList( + KV.of(1, KV.of("in 1a", 3)), + KV.of(1, KV.of("in 1b", 9)), + KV.of(2, KV.of("in 2a", 2)), + KV.of(2, KV.of("in 2b", 77)), + KV.of(3, KV.of("in 3", 33)), + KV.of(4, KV.of("in 4a", -123)), + KV.of(4, KV.of("in 4b", 0)), + KV.of(4, KV.of("in 4c", -1)), + KV.of(4, KV.of("in 4d", 1)), + KV.of(5, KV.of("in 5", 666))); + + static final Instant timestamp = new Instant(123000); + static final IntervalWindow window = new IntervalWindow(timestamp, timestamp.plus(1000)); + + void runTestWriteUngroupingShuffleSink(List expected) + throws Exception { + Coder> windowedValueCoder = + WindowedValue.getFullCoder(BigEndianIntegerCoder.of(), new GlobalWindow().windowCoder()); + ShuffleSink shuffleSink = new ShuffleSink<>( + PipelineOptionsFactory.create(), + null, ShuffleSink.ShuffleKind.UNGROUPED, + windowedValueCoder); + + TestShuffleWriter shuffleWriter = new TestShuffleWriter(); + List actualSizes = new ArrayList<>(); + try (Sink.SinkWriter> shuffleSinkWriter = + shuffleSink.writer(shuffleWriter)) { + for (Integer value : expected) { + actualSizes.add(shuffleSinkWriter.add(WindowedValue.valueInGlobalWindow(value))); + } + } + + List records = shuffleWriter.getRecords(); + + List actual = new ArrayList<>(); + for (ShuffleEntry record : records) { + // Ignore the key. + byte[] valueBytes = record.getValue(); + WindowedValue value = CoderUtils.decodeFromByteArray(windowedValueCoder, valueBytes); + Assert.assertEquals(Lists.newArrayList(GlobalWindow.Window.INSTANCE), value.getWindows()); + actual.add(value.getValue()); + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(shuffleWriter.getSizes(), actualSizes); + } + + void runTestWriteGroupingShuffleSink( + List> expected) + throws Exception { + ShuffleSink> shuffleSink = new ShuffleSink<>( + PipelineOptionsFactory.create(), + null, ShuffleSink.ShuffleKind.GROUP_KEYS, + WindowedValue.getFullCoder( + KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()), + IntervalWindow.getCoder())); + + TestShuffleWriter shuffleWriter = new TestShuffleWriter(); + List actualSizes = new ArrayList<>(); + try (SinkWriter>> shuffleSinkWriter = + shuffleSink.writer(shuffleWriter)) { + for (KV kv : expected) { + actualSizes.add(shuffleSinkWriter.add( + WindowedValue.of(KV.of(kv.getKey(), kv.getValue()), + timestamp, + Lists.newArrayList(window)))); + } + } + + List records = shuffleWriter.getRecords(); + + List> actual = new ArrayList<>(); + for (ShuffleEntry record : records) { + byte[] keyBytes = record.getKey(); + byte[] valueBytes = record.getValue(); + Assert.assertEquals(timestamp, + CoderUtils.decodeFromByteArray(InstantCoder.of(), record.getSecondaryKey())); + + Integer key = + CoderUtils.decodeFromByteArray(BigEndianIntegerCoder.of(), + keyBytes); + String valueElem = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), valueBytes); + + actual.add(KV.of(key, valueElem)); + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(shuffleWriter.getSizes(), actualSizes); + } + + void runTestWriteGroupingSortingShuffleSink( + List>> expected) + throws Exception { + ShuffleSink>> shuffleSink = + new ShuffleSink<>( + PipelineOptionsFactory.create(), + null, + ShuffleSink.ShuffleKind.GROUP_KEYS_AND_SORT_VALUES, + WindowedValue.getFullCoder( + KvCoder.of(BigEndianIntegerCoder.of(), + KvCoder.of(StringUtf8Coder.of(), + BigEndianIntegerCoder.of())), + new GlobalWindow().windowCoder())); + + TestShuffleWriter shuffleWriter = new TestShuffleWriter(); + List actualSizes = new ArrayList<>(); + try (Sink.SinkWriter>>> shuffleSinkWriter = + shuffleSink.writer(shuffleWriter)) { + for (KV> kv : expected) { + actualSizes.add(shuffleSinkWriter.add(WindowedValue.valueInGlobalWindow(kv))); + } + } + + List records = shuffleWriter.getRecords(); + + List>> actual = new ArrayList<>(); + for (ShuffleEntry record : records) { + byte[] keyBytes = record.getKey(); + byte[] valueBytes = record.getValue(); + byte[] sortKeyBytes = record.getSecondaryKey(); + + Integer key = + CoderUtils.decodeFromByteArray(BigEndianIntegerCoder.of(), + keyBytes); + String sortKey = + CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), + sortKeyBytes); + Integer sortValue = CoderUtils.decodeFromByteArray(BigEndianIntegerCoder.of(), valueBytes); + + actual.add(KV.of(key, KV.of(sortKey, sortValue))); + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(shuffleWriter.getSizes(), actualSizes); + } + + @Test + public void testWriteEmptyUngroupingShuffleSink() throws Exception { + runTestWriteUngroupingShuffleSink(TestUtils.NO_INTS); + } + + @Test + public void testWriteNonEmptyUngroupingShuffleSink() throws Exception { + runTestWriteUngroupingShuffleSink(TestUtils.INTS); + } + + @Test + public void testWriteEmptyGroupingShuffleSink() throws Exception { + runTestWriteGroupingShuffleSink(NO_KVS); + } + + @Test + public void testWriteNonEmptyGroupingShuffleSink() throws Exception { + runTestWriteGroupingShuffleSink(KVS); + } + + @Test + public void testWriteEmptyGroupingSortingShuffleSink() throws Exception { + runTestWriteGroupingSortingShuffleSink(NO_SORTING_KVS); + } + + @Test + public void testWriteNonEmptyGroupingSortingShuffleSink() throws Exception { + runTestWriteGroupingSortingShuffleSink(SORTING_KVS); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSourceFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSourceFactoryTest.java new file mode 100644 index 000000000000..75fc7479687e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/ShuffleSourceFactoryTest.java @@ -0,0 +1,230 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.api.client.util.Base64.encodeBase64String; +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import javax.annotation.Nullable; + +/** + * Tests for UngroupedShuffleSourceFactory, GroupingShuffleSourceFactory, + * and PartitioningShuffleSourceFactory. + */ +@RunWith(JUnit4.class) +public class ShuffleSourceFactoryTest { + + T runTestCreateShuffleSource(byte[] shuffleReaderConfig, + @Nullable String start, + @Nullable String end, + CloudObject encoding, + BatchModeExecutionContext context, + Class shuffleSourceClass) + throws Exception { + CloudObject spec = CloudObject.forClassName(shuffleSourceClass.getSimpleName()); + addString(spec, "shuffle_reader_config", encodeBase64String(shuffleReaderConfig)); + if (start != null) { + addString(spec, "start_shuffle_position", start); + } + if (end != null) { + addString(spec, "end_shuffle_position", end); + } + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(encoding); + + Source source = SourceFactory.create( + PipelineOptionsFactory.create(), cloudSource, context); + Assert.assertThat(source, new IsInstanceOf(shuffleSourceClass)); + T shuffleSource = (T) source; + return shuffleSource; + } + + void runTestCreateUngroupedShuffleSource(byte[] shuffleReaderConfig, + @Nullable String start, + @Nullable String end, + CloudObject encoding, + Coder coder) throws Exception { + UngroupedShuffleSource shuffleSource = + runTestCreateShuffleSource(shuffleReaderConfig, + start, + end, + encoding, + new BatchModeExecutionContext(), + UngroupedShuffleSource.class); + Assert.assertArrayEquals(shuffleReaderConfig, + shuffleSource.shuffleReaderConfig); + Assert.assertEquals(start, shuffleSource.startShufflePosition); + Assert.assertEquals(end, shuffleSource.stopShufflePosition); + + Assert.assertEquals(coder, shuffleSource.coder); + } + + void runTestCreateGroupingShuffleSource(byte[] shuffleReaderConfig, + @Nullable String start, + @Nullable String end, + CloudObject encoding, + Coder keyCoder, + Coder valueCoder) throws Exception { + BatchModeExecutionContext context = new BatchModeExecutionContext(); + GroupingShuffleSource shuffleSource = + runTestCreateShuffleSource(shuffleReaderConfig, + start, + end, + encoding, + context, + GroupingShuffleSource.class); + Assert.assertArrayEquals(shuffleReaderConfig, + shuffleSource.shuffleReaderConfig); + Assert.assertEquals(start, shuffleSource.startShufflePosition); + Assert.assertEquals(end, shuffleSource.stopShufflePosition); + + Assert.assertEquals(keyCoder, shuffleSource.keyCoder); + Assert.assertEquals(valueCoder, shuffleSource.valueCoder); + Assert.assertEquals(context, shuffleSource.executionContext); + } + + void runTestCreatePartitioningShuffleSource(byte[] shuffleReaderConfig, + @Nullable String start, + @Nullable String end, + CloudObject encoding, + Coder keyCoder, + Coder windowedValueCoder) throws Exception { + PartitioningShuffleSource shuffleSource = + runTestCreateShuffleSource(shuffleReaderConfig, + start, + end, + encoding, + new BatchModeExecutionContext(), + PartitioningShuffleSource.class); + Assert.assertArrayEquals(shuffleReaderConfig, + shuffleSource.shuffleReaderConfig); + Assert.assertEquals(start, shuffleSource.startShufflePosition); + Assert.assertEquals(end, shuffleSource.stopShufflePosition); + + Assert.assertEquals(keyCoder, shuffleSource.keyCoder); + Assert.assertEquals(windowedValueCoder, shuffleSource.windowedValueCoder); + } + + @Test + public void testCreatePlainUngroupedShuffleSource() throws Exception { + runTestCreateUngroupedShuffleSource( + new byte[]{(byte) 0xE1}, null, null, + makeCloudEncoding("StringUtf8Coder"), + StringUtf8Coder.of()); + } + + @Test + public void testCreateRichUngroupedShuffleSource() throws Exception { + runTestCreateUngroupedShuffleSource( + new byte[]{(byte) 0xE2}, "aaa", "zzz", + makeCloudEncoding("BigEndianIntegerCoder"), + BigEndianIntegerCoder.of()); + } + + @Test + public void testCreatePlainGroupingShuffleSource() throws Exception { + runTestCreateGroupingShuffleSource( + new byte[]{(byte) 0xE1}, null, null, + makeCloudEncoding( + FullWindowedValueCoder.class.getName(), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("BigEndianIntegerCoder"), + makeCloudEncoding( + "IterableCoder", + makeCloudEncoding("StringUtf8Coder"))), + IntervalWindow.getCoder().asCloudObject()), + BigEndianIntegerCoder.of(), + StringUtf8Coder.of()); + } + + @Test + public void testCreateRichGroupingShuffleSource() throws Exception { + runTestCreateGroupingShuffleSource( + new byte[]{(byte) 0xE2}, "aaa", "zzz", + makeCloudEncoding( + FullWindowedValueCoder.class.getName(), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("BigEndianIntegerCoder"), + makeCloudEncoding( + "IterableCoder", + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("StringUtf8Coder"), + makeCloudEncoding("VoidCoder")))), + IntervalWindow.getCoder().asCloudObject()), + BigEndianIntegerCoder.of(), + KvCoder.of(StringUtf8Coder.of(), VoidCoder.of())); + } + + @Test + public void testCreatePlainPartitioningShuffleSource() throws Exception { + runTestCreatePartitioningShuffleSource( + new byte[]{(byte) 0xE1}, null, null, + makeCloudEncoding( + FullWindowedValueCoder.class.getName(), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("BigEndianIntegerCoder"), + makeCloudEncoding("StringUtf8Coder")), + IntervalWindow.getCoder().asCloudObject()), + BigEndianIntegerCoder.of(), + FullWindowedValueCoder.of(StringUtf8Coder.of(), IntervalWindow.getCoder())); + } + + @Test + public void testCreateRichPartitioningShuffleSource() throws Exception { + runTestCreatePartitioningShuffleSource( + new byte[]{(byte) 0xE2}, "aaa", "zzz", + makeCloudEncoding( + FullWindowedValueCoder.class.getName(), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("BigEndianIntegerCoder"), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("StringUtf8Coder"), + makeCloudEncoding("VoidCoder"))), + IntervalWindow.getCoder().asCloudObject()), + BigEndianIntegerCoder.of(), + FullWindowedValueCoder.of(KvCoder.of(StringUtf8Coder.of(), VoidCoder.of()), + IntervalWindow.getCoder())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SideInputUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SideInputUtilsTest.java new file mode 100644 index 000000000000..ea879335ec02 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SideInputUtilsTest.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.hamcrest.core.Is.is; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.api.services.dataflow.model.SideInputInfo; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for SideInputUtils. + */ +@RunWith(JUnit4.class) +public class SideInputUtilsTest { + SideInputInfo createSingletonSideInputInfo( + com.google.api.services.dataflow.model.Source sideInputSource) { + SideInputInfo sideInputInfo = new SideInputInfo(); + sideInputInfo.setSources(Arrays.asList(sideInputSource)); + sideInputInfo.setKind(CloudObject.forClassName("singleton")); + return sideInputInfo; + } + + SideInputInfo createCollectionSideInputInfo( + com.google.api.services.dataflow.model.Source... sideInputSources) { + SideInputInfo sideInputInfo = new SideInputInfo(); + sideInputInfo.setSources(Arrays.asList(sideInputSources)); + sideInputInfo.setKind(CloudObject.forClassName("collection")); + return sideInputInfo; + } + + com.google.api.services.dataflow.model.Source createSideInputSource(Integer... ints) + throws Exception { + return InMemorySourceFactoryTest.createInMemoryCloudSource( + Arrays.asList(ints), + null, null, + BigEndianIntegerCoder.of()); + } + + void assertThatContains(Object actual, Object... expected) { + assertThat(actual, instanceOf(Iterable.class)); + Iterable iter = (Iterable) actual; + if (expected.length == 0) { + assertThat(iter, is(emptyIterable())); + } else { + assertThat(iter, contains(expected)); + } + } + + @Test + public void testReadSingletonSideInput() throws Exception { + SideInputInfo sideInputInfo = + createSingletonSideInputInfo(createSideInputSource(42)); + + assertEquals(42, + SideInputUtils.readSideInput(PipelineOptionsFactory.create(), + sideInputInfo, + new BatchModeExecutionContext())); + } + + @Test + public void testReadEmptyCollectionSideInput() throws Exception { + SideInputInfo sideInputInfo = + createCollectionSideInputInfo(createSideInputSource()); + + assertThatContains( + SideInputUtils.readSideInput(PipelineOptionsFactory.create(), + sideInputInfo, + new BatchModeExecutionContext())); + } + + @Test + public void testReadCollectionSideInput() throws Exception { + SideInputInfo sideInputInfo = + createCollectionSideInputInfo(createSideInputSource(3, 4, 5, 6)); + + assertThatContains( + SideInputUtils.readSideInput(PipelineOptionsFactory.create(), + sideInputInfo, + new BatchModeExecutionContext()), + 3, 4, 5, 6); + } + + @Test + public void testReadCollectionShardedSideInput() throws Exception { + SideInputInfo sideInputInfo = + createCollectionSideInputInfo( + createSideInputSource(3), + createSideInputSource(), + createSideInputSource(4, 5), + createSideInputSource(6), + createSideInputSource()); + + assertThatContains( + SideInputUtils.readSideInput(PipelineOptionsFactory.create(), + sideInputInfo, + new BatchModeExecutionContext()), + 3, 4, 5, 6); + } + + @Test + public void testReadSingletonSideInputValue() throws Exception { + CloudObject sideInputKind = CloudObject.forClassName("singleton"); + Object elem = "hi"; + List elems = Arrays.asList(elem); + assertEquals(elem, + SideInputUtils.readSideInputValue(sideInputKind, elems)); + } + + @Test + public void testReadCollectionSideInputValue() throws Exception { + CloudObject sideInputKind = CloudObject.forClassName("collection"); + List elems = Arrays.asList("hi", "there", "bob"); + assertEquals(elems, + SideInputUtils.readSideInputValue(sideInputKind, elems)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SinkFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SinkFactoryTest.java new file mode 100644 index 000000000000..66e72545cb71 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SinkFactoryTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for SinkFactory. + */ +@RunWith(JUnit4.class) +public class SinkFactoryTest { + static class TestSinkFactory { + public static TestSink create(PipelineOptions options, + CloudObject o, + Coder coder, + ExecutionContext executionContext) { + return new TestSink(); + } + } + + static class TestSink extends Sink { + @Override + public SinkWriter writer() { + return new TestSinkWriter(); + } + + /** A sink writer that drops its input values, for testing. */ + class TestSinkWriter implements SinkWriter { + @Override + public long add(Integer outputElem) { + return 4; + } + + @Override + public void close() { + } + } + } + + @Test + public void testCreatePredefinedSink() throws Exception { + CloudObject spec = CloudObject.forClassName("TextSink"); + addString(spec, "filename", "/path/to/file.txt"); + + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(makeCloudEncoding("StringUtf8Coder")); + + Sink sink = SinkFactory.create(PipelineOptionsFactory.create(), + cloudSink, + new BatchModeExecutionContext()); + Assert.assertThat(sink, new IsInstanceOf(TextSink.class)); + } + + @Test + public void testCreateUserDefinedSink() throws Exception { + CloudObject spec = CloudObject.forClass(TestSinkFactory.class); + + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(makeCloudEncoding("BigEndianIntegerCoder")); + + Sink sink = SinkFactory.create(PipelineOptionsFactory.create(), + cloudSink, + new BatchModeExecutionContext()); + Assert.assertThat(sink, new IsInstanceOf(TestSink.class)); + } + + @Test + public void testCreateUnknownSink() throws Exception { + CloudObject spec = CloudObject.forClassName("UnknownSink"); + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(makeCloudEncoding("StringUtf8Coder")); + try { + SinkFactory.create(PipelineOptionsFactory.create(), + cloudSink, + new BatchModeExecutionContext()); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + CoreMatchers.containsString( + "unable to create a sink")); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SourceFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SourceFactoryTest.java new file mode 100644 index 000000000000..4b4665b55869 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/SourceFactoryTest.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.NoSuchElementException; + +/** + * Tests for SourceFactory. + */ +@RunWith(JUnit4.class) +public class SourceFactoryTest { + static class TestSourceFactory { + public static TestSource create(PipelineOptions options, + CloudObject o, + Coder coder, + ExecutionContext executionContext) { + return new TestSource(); + } + } + + static class TestSource extends Source { + @Override + public SourceIterator iterator() { + return new TestSourceIterator(); + } + + /** A source iterator that produces no values, for testing. */ + class TestSourceIterator extends AbstractSourceIterator { + @Override + public boolean hasNext() { return false; } + + @Override + public Integer next() { + throw new NoSuchElementException(); + } + + @Override + public void close() { + } + } + } + + @Test + public void testCreatePredefinedSource() throws Exception { + CloudObject spec = CloudObject.forClassName("TextSource"); + addString(spec, "filename", "/path/to/file.txt"); + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(makeCloudEncoding("StringUtf8Coder")); + + Source source = SourceFactory.create(PipelineOptionsFactory.create(), + cloudSource, + new BatchModeExecutionContext()); + Assert.assertThat(source, new IsInstanceOf(TextSource.class)); + } + + @Test + public void testCreateUserDefinedSource() throws Exception { + CloudObject spec = CloudObject.forClass(TestSourceFactory.class); + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(makeCloudEncoding("BigEndianIntegerCoder")); + + Source source = SourceFactory.create(PipelineOptionsFactory.create(), + cloudSource, + new BatchModeExecutionContext()); + Assert.assertThat(source, new IsInstanceOf(TestSource.class)); + } + + @Test + public void testCreateUnknownSource() throws Exception { + CloudObject spec = CloudObject.forClassName("UnknownSource"); + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(makeCloudEncoding("StringUtf8Coder")); + try { + SourceFactory.create(PipelineOptionsFactory.create(), + cloudSource, + new BatchModeExecutionContext()); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + CoreMatchers.containsString( + "unable to create a source")); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleReader.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleReader.java new file mode 100644 index 000000000000..4d5e85881be9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleReader.java @@ -0,0 +1,177 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.util.common.Reiterator; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntryReader; +import com.google.cloud.dataflow.sdk.util.common.worker.ShufflePosition; +// TODO: Decide how we want to handle this Guava dependency. +import com.google.common.primitives.UnsignedBytes; + +import org.junit.Assert; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.NavigableMap; +import java.util.NoSuchElementException; +import java.util.TreeMap; + +/** + * A fake implementation of a ShuffleEntryReader, for testing. + */ +public class TestShuffleReader implements ShuffleEntryReader { + static final Comparator SHUFFLE_KEY_COMPARATOR = + UnsignedBytes.lexicographicalComparator(); + final NavigableMap> records; + + public TestShuffleReader(NavigableMap> records) { + this.records = records; + } + + public TestShuffleReader() { + this(new TreeMap>(SHUFFLE_KEY_COMPARATOR)); + } + + public void addEntry(String key, String value) { + addEntry(key.getBytes(), value.getBytes()); + } + + public void addEntry(byte[] key, byte[] value) { + addEntry(new ShuffleEntry(key, null, value)); + } + + public void addEntry(ShuffleEntry entry) { + List values = records.get(entry.getKey()); + if (values == null) { + values = new ArrayList<>(); + records.put(entry.getKey(), values); + } + values.add(entry); + } + + public Iterator read() { + return read((byte[]) null, (byte[]) null); + } + + @Override + public Reiterator read(ShufflePosition startPosition, + ShufflePosition endPosition) { + return read(ByteArrayShufflePosition.getPosition(startPosition), + ByteArrayShufflePosition.getPosition(endPosition)); + } + + public Reiterator read(String startKey, String endKey) { + return read(startKey == null ? null : startKey.getBytes(), + endKey == null ? null : endKey.getBytes()); + } + + public Reiteratorread(byte[] startKey, byte[] endKey) { + return new ShuffleReaderIterator(startKey, endKey); + } + + class ShuffleReaderIterator implements Reiterator { + final Iterator>> recordsIter; + final byte[] startKey; + final byte[] endKey; + byte[] currentKey; + Map.Entry> currentRecord; + ListIterator currentValuesIter; + + public ShuffleReaderIterator(byte[] startKey, byte[] endKey) { + this.recordsIter = records.entrySet().iterator(); + this.startKey = startKey; + this.endKey = endKey; + advanceKey(); + } + + private ShuffleReaderIterator(ShuffleReaderIterator it) { + if (it.currentKey != null) { + this.recordsIter = + records.tailMap(it.currentKey, false).entrySet().iterator(); + } else { + this.recordsIter = null; + } + this.startKey = it.startKey; + this.endKey = it.endKey; + this.currentKey = it.currentKey; + this.currentRecord = it.currentRecord; + if (it.currentValuesIter != null) { + this.currentValuesIter = + it.currentRecord.getValue().listIterator( + it.currentValuesIter.nextIndex()); + } else { + this.currentValuesIter = null; + } + } + + @Override + public boolean hasNext() { + return currentKey != null; + } + + @Override + public ShuffleEntry next() { + if (currentKey == null) { + throw new NoSuchElementException(); + } + ShuffleEntry resultValue = currentValuesIter.next(); + Assert.assertTrue(Arrays.equals(currentKey, resultValue.getKey())); + if (!currentValuesIter.hasNext()) { + advanceKey(); + } + return resultValue; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public Reiterator copy() { + return new ShuffleReaderIterator(this); + } + + private void advanceKey() { + while (recordsIter.hasNext()) { + currentRecord = recordsIter.next(); + currentKey = currentRecord.getKey(); + if (startKey != null && + SHUFFLE_KEY_COMPARATOR.compare(currentKey, startKey) < 0) { + // This key is before the start of the range. Keep looking. + continue; + } + if (endKey != null && + SHUFFLE_KEY_COMPARATOR.compare(currentKey, endKey) >= 0) { + // This key is at or after the end of the range. Stop looking. + break; + } + // In range. + currentValuesIter = currentRecord.getValue().listIterator(); + return; + } + currentKey = null; + currentValuesIter = null; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleReaderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleReaderTest.java new file mode 100644 index 000000000000..87935a7bb3d8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleReaderTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * Tests of TestShuffleReader. + */ +@RunWith(JUnit4.class) +public class TestShuffleReaderTest { + static final String START_KEY = "ddd"; + static final String END_KEY = "mmm"; + + static final List> NO_ENTRIES = + Collections.emptyList(); + + static final List> IN_RANGE_ENTRIES = + Arrays.>asList( + new SimpleEntry<>("ddd", "in 1"), + new SimpleEntry<>("ddd", "in 1"), + new SimpleEntry<>("ddd", "in 1"), + new SimpleEntry<>("dddd", "in 2"), + new SimpleEntry<>("dddd", "in 2"), + new SimpleEntry<>("de", "in 3"), + new SimpleEntry<>("ee", "in 4"), + new SimpleEntry<>("ee", "in 4"), + new SimpleEntry<>("ee", "in 4"), + new SimpleEntry<>("ee", "in 4"), + new SimpleEntry<>("mm", "in 5")); + static final List> BEFORE_RANGE_ENTRIES = + Arrays.>asList( + new SimpleEntry<>("", "out 1"), + new SimpleEntry<>("dd", "out 2")); + static final List> AFTER_RANGE_ENTRIES = + Arrays.>asList( + new SimpleEntry<>("mmm", "out 3"), + new SimpleEntry<>("mmm", "out 3"), + new SimpleEntry<>("mmmm", "out 4"), + new SimpleEntry<>("mn", "out 5"), + new SimpleEntry<>("zzz", "out 6")); + static final List> OUT_OF_RANGE_ENTRIES = + new ArrayList<>(); + static { + OUT_OF_RANGE_ENTRIES.addAll(BEFORE_RANGE_ENTRIES); + OUT_OF_RANGE_ENTRIES.addAll(AFTER_RANGE_ENTRIES); + } + static final List> ALL_ENTRIES = new ArrayList<>(); + static { + ALL_ENTRIES.addAll(BEFORE_RANGE_ENTRIES); + ALL_ENTRIES.addAll(IN_RANGE_ENTRIES); + ALL_ENTRIES.addAll(AFTER_RANGE_ENTRIES); + } + + void runTest(List> expected, + List> outOfRange, + String startKey, + String endKey) { + TestShuffleReader shuffleReader = new TestShuffleReader(); + List> expectedCopy = new ArrayList<>(expected); + expectedCopy.addAll(outOfRange); + Collections.shuffle(expectedCopy); + for (Map.Entry entry : expectedCopy) { + shuffleReader.addEntry(entry.getKey(), entry.getValue()); + } + Iterator iter = shuffleReader.read(startKey, endKey); + List> actual = new ArrayList<>(); + while (iter.hasNext()) { + ShuffleEntry entry = iter.next(); + actual.add(new SimpleEntry<>(new String(entry.getKey()), + new String(entry.getValue()))); + } + try { + iter.next(); + Assert.fail("should have failed"); + } catch (NoSuchElementException exn) { + // Success. + } + Assert.assertEquals(expected, actual); + } + + @Test + public void testEmpty() { + runTest(NO_ENTRIES, NO_ENTRIES, null, null); + } + + @Test + public void testEmptyWithRange() { + runTest(NO_ENTRIES, NO_ENTRIES, START_KEY, END_KEY); + } + + @Test + public void testNonEmpty() { + runTest(ALL_ENTRIES, NO_ENTRIES, null, null); + } + + @Test + public void testNonEmptyWithAllInRange() { + runTest(IN_RANGE_ENTRIES, NO_ENTRIES, START_KEY, END_KEY); + } + + @Test + public void testNonEmptyWithSomeOutOfRange() { + runTest(IN_RANGE_ENTRIES, OUT_OF_RANGE_ENTRIES, START_KEY, END_KEY); + } + + @Test + public void testNonEmptyWithAllOutOfRange() { + runTest(NO_ENTRIES, OUT_OF_RANGE_ENTRIES, START_KEY, END_KEY); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleWriter.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleWriter.java new file mode 100644 index 000000000000..4fde0bbcdbaa --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TestShuffleWriter.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; + +import java.util.ArrayList; +import java.util.List; + +/** + * A fake implementation of a ShuffleEntryWriter, for testing. + */ +public class TestShuffleWriter implements ShuffleEntryWriter { + final List records = new ArrayList<>(); + final List sizes = new ArrayList<>(); + boolean closed = false; + + public TestShuffleWriter() { } + + @Override + public long put(ShuffleEntry entry) { + if (closed) { + throw new AssertionError("shuffle writer already closed"); + } + records.add(entry); + + long size = entry.length(); + sizes.add(size); + return size; + } + + @Override + public void close() { + if (closed) { + throw new AssertionError("shuffle writer already closed"); + } + closed = true; + } + + /** Returns the key/value records that were written to this ShuffleWriter. */ + public List getRecords() { + if (!closed) { + throw new AssertionError("shuffle writer not closed"); + } + return records; + } + + /** Returns the sizes in bytes of the records that were written to this ShuffleWriter. */ + public List getSizes() { + if (!closed) { + throw new AssertionError("shuffle writer not closed"); + } + return sizes; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkFactoryTest.java new file mode 100644 index 000000000000..9f9e63090a6e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkFactoryTest.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TextualIntegerCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import javax.annotation.Nullable; + +/** + * Tests for TextSinkFactory. + */ +@RunWith(JUnit4.class) +public class TextSinkFactoryTest { + void runTestCreateTextSink(String filename, + @Nullable Boolean appendTrailingNewlines, + @Nullable String header, + @Nullable String footer, + CloudObject encoding, + Coder coder) + throws Exception { + CloudObject spec = CloudObject.forClassName("TextSink"); + addString(spec, PropertyNames.FILENAME, filename); + if (appendTrailingNewlines != null) { + addBoolean(spec, PropertyNames.APPEND_TRAILING_NEWLINES, appendTrailingNewlines); + } + if (header != null) { + addString(spec, PropertyNames.HEADER, header); + } + if (footer != null) { + addString(spec, PropertyNames.FOOTER, footer); + } + + com.google.api.services.dataflow.model.Sink cloudSink = + new com.google.api.services.dataflow.model.Sink(); + cloudSink.setSpec(spec); + cloudSink.setCodec(encoding); + + Sink sink = SinkFactory.create(PipelineOptionsFactory.create(), + cloudSink, + new BatchModeExecutionContext()); + Assert.assertThat(sink, new IsInstanceOf(TextSink.class)); + TextSink textSink = (TextSink) sink; + Assert.assertEquals(filename, textSink.namePrefix); + Assert.assertEquals( + appendTrailingNewlines == null ? true : appendTrailingNewlines, + textSink.appendTrailingNewlines); + Assert.assertEquals(header, textSink.header); + Assert.assertEquals(footer, textSink.footer); + Assert.assertEquals(coder, textSink.coder); + } + + @Test + public void testCreatePlainTextSink() throws Exception { + runTestCreateTextSink( + "/path/to/file.txt", null, null, null, + makeCloudEncoding("StringUtf8Coder"), + StringUtf8Coder.of()); + } + + @Test + public void testCreateRichTextSink() throws Exception { + runTestCreateTextSink( + "gs://bucket/path/to/file2.txt", false, "$$$", "***", + makeCloudEncoding("TextualIntegerCoder"), + TextualIntegerCoder.of()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkTest.java new file mode 100644 index 000000000000..d1b8b436a251 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSinkTest.java @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TextualIntegerCoder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.Sink; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Tests for TextSink. + */ +@RunWith(JUnit4.class) +public class TextSinkTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + void runTestWriteFile(List elems, + @Nullable String header, + @Nullable String footer, + Coder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + TextSink> textSink = TextSink.createForTest( + tmpFile.getPath(), true, header, footer, coder); + List expected = new ArrayList<>(); + List actualSizes = new ArrayList<>(); + if (header != null) { + expected.add(header); + } + try (Sink.SinkWriter> writer = textSink.writer()) { + for (T elem : elems) { + actualSizes.add((int) writer.add(WindowedValue.valueInGlobalWindow(elem))); + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, elem); + String line = new String(encodedElem); + expected.add(line); + } + } + if (footer != null) { + expected.add(footer); + } + + BufferedReader reader = new BufferedReader(new FileReader(tmpFile)); + List actual = new ArrayList<>(); + List expectedSizes = new ArrayList<>(); + for (;;) { + String line = reader.readLine(); + if (line == null) { + break; + } + actual.add(line); + expectedSizes.add(line.length() + TextSink.NEWLINE.length); + } + if (header != null) { + expectedSizes.remove(0); + } + if (footer != null) { + expectedSizes.remove(expectedSizes.size() - 1); + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(expectedSizes, actualSizes); + } + + @Test + public void testWriteEmptyFile() throws Exception { + runTestWriteFile(Collections.emptyList(), null, null, + StringUtf8Coder.of()); + } + + @Test + public void testWriteEmptyFileWithHeaderAndFooter() throws Exception { + runTestWriteFile(Collections.emptyList(), "the head", "the foot", + StringUtf8Coder.of()); + } + + @Test + public void testWriteNonEmptyFile() throws Exception { + List lines = Arrays.asList( + "", + " hi there ", + "bob", + "", + " ", + "--zowie!--", + ""); + runTestWriteFile(lines, null, null, StringUtf8Coder.of()); + } + + @Test + public void testWriteNonEmptyFileWithHeaderAndFooter() throws Exception { + List lines = Arrays.asList( + "", + " hi there ", + "bob", + "", + " ", + "--zowie!--", + ""); + runTestWriteFile(lines, "the head", "the foot", StringUtf8Coder.of()); + } + + @Test + public void testWriteNonEmptyNonStringFile() throws Exception { + runTestWriteFile(TestUtils.INTS, null, null, TextualIntegerCoder.of()); + } + + // TODO: sharded filenames + // TODO: not appending newlines + // TODO: writing to GCS +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceFactoryTest.java new file mode 100644 index 000000000000..2fa50b567e78 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceFactoryTest.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TextualIntegerCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.BatchModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import javax.annotation.Nullable; + +/** + * Tests for TextSourceFactory. + */ +@RunWith(JUnit4.class) +public class TextSourceFactoryTest { + void runTestCreateTextSource(String filename, + @Nullable Boolean stripTrailingNewlines, + @Nullable Long start, + @Nullable Long end, + CloudObject encoding, + Coder coder) + throws Exception { + CloudObject spec = CloudObject.forClassName("TextSource"); + addString(spec, "filename", filename); + if (stripTrailingNewlines != null) { + addBoolean(spec, "strip_trailing_newlines", stripTrailingNewlines); + } + if (start != null) { + addLong(spec, "start_offset", start); + } + if (end != null) { + addLong(spec, "end_offset", end); + } + + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + cloudSource.setSpec(spec); + cloudSource.setCodec(encoding); + + Source source = SourceFactory.create(PipelineOptionsFactory.create(), + cloudSource, + new BatchModeExecutionContext()); + Assert.assertThat(source, new IsInstanceOf(TextSource.class)); + TextSource textSource = (TextSource) source; + Assert.assertEquals(filename, textSource.filename); + Assert.assertEquals( + stripTrailingNewlines == null ? true : stripTrailingNewlines, + textSource.stripTrailingNewlines); + Assert.assertEquals(start, textSource.startPosition); + Assert.assertEquals(end, textSource.endPosition); + Assert.assertEquals(coder, textSource.coder); + } + + @Test + public void testCreatePlainTextSource() throws Exception { + runTestCreateTextSource( + "/path/to/file.txt", null, null, null, + makeCloudEncoding("StringUtf8Coder"), + StringUtf8Coder.of()); + } + + @Test + public void testCreateRichTextSource() throws Exception { + runTestCreateTextSource( + "gs://bucket/path/to/file2.txt", false, 200L, 500L, + makeCloudEncoding("TextualIntegerCoder"), + TextualIntegerCoder.of()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceTest.java new file mode 100644 index 000000000000..8aee7aaf0052 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/TextSourceTest.java @@ -0,0 +1,581 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static org.hamcrest.Matchers.greaterThan; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TextualIntegerCoder; +import com.google.cloud.dataflow.sdk.runners.worker.TextSource.TextFileIterator; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for TextSource. + */ +@RunWith(JUnit4.class) +public class TextSourceTest { + private static final String[] fileContent = {"First line\n", + "Second line\r\n", + "Third line"}; + private static final long TOTAL_BYTES_COUNT; + + static { + long sumLen = 0L; + for (String s : fileContent) { + sumLen += s.length(); + } + TOTAL_BYTES_COUNT = sumLen; + } + + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private File initTestFile() throws IOException { + File tmpFile = tmpFolder.newFile(); + FileOutputStream output = new FileOutputStream(tmpFile); + for (String s : fileContent) { + output.write(s.getBytes()); + } + output.close(); + + return tmpFile; + } + + @Test + public void testReadEmptyFile() throws Exception { + TextSource textSource = new TextSource<>( + "/dev/null", true, null, null, StringUtf8Coder.of()); + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertFalse(iterator.hasNext()); + } + } + + @Test + public void testStrippedNewlines() throws Exception { + testNewlineHandling("\r", true); + testNewlineHandling("\r\n", true); + testNewlineHandling("\n", true); + } + + @Test + public void testStrippedNewlinesAtEndOfReadBuffer() throws Exception { + boolean stripNewLines = true; + StringBuilder payload = new StringBuilder(); + for (int i = 0; i < TextSource.BUF_SIZE - 2; ++i) { + payload.append('a'); + } + String[] lines = {payload.toString(), payload.toString()}; + testStringPayload(lines , "\r", stripNewLines); + testStringPayload(lines , "\r\n", stripNewLines); + testStringPayload(lines , "\n", stripNewLines); + } + + @Test + public void testUnstrippedNewlines() throws Exception { + testNewlineHandling("\r", false); + testNewlineHandling("\r\n", false); + testNewlineHandling("\n", false); + } + + @Test + public void testUnstrippedNewlinesAtEndOfReadBuffer() throws Exception { + boolean stripNewLines = false; + StringBuilder payload = new StringBuilder(); + for (int i = 0; i < TextSource.BUF_SIZE - 2; ++i) { + payload.append('a'); + } + String[] lines = {payload.toString(), payload.toString()}; + testStringPayload(lines , "\r", stripNewLines); + testStringPayload(lines , "\r\n", stripNewLines); + testStringPayload(lines , "\n", stripNewLines); + } + + @Test + public void testStartPosition() throws Exception { + File tmpFile = initTestFile(); + + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, 11L, null, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertEquals("Second line\r\n", iterator.next()); + Assert.assertEquals("Third line", iterator.next()); + Assert.assertFalse(iterator.hasNext()); + // The first '1' in the array represents the reading of '\n' between first and + // second line, to confirm that we are reading from the beginning of a record. + Assert.assertEquals(Arrays.asList(1, 13, 10), observer.getActualSizes()); + } + } + + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, 20L, null, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertEquals("Third line", iterator.next()); + Assert.assertFalse(iterator.hasNext()); + // The first '5' in the array represents the reading of a portion of the second + // line, which had to be read to find the beginning of the third line. + Assert.assertEquals(Arrays.asList(5, 10), observer.getActualSizes()); + } + } + + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), true, 0L, 20L, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertEquals("First line", iterator.next()); + Assert.assertEquals("Second line", iterator.next()); + Assert.assertFalse(iterator.hasNext()); + Assert.assertEquals(Arrays.asList(11, 13), observer.getActualSizes()); + } + } + + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), true, 1L, 20L, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertEquals("Second line", iterator.next()); + Assert.assertFalse(iterator.hasNext()); + // The first '11' in the array represents the reading of the entire first + // line, which had to be read to find the beginning of the second line. + Assert.assertEquals(Arrays.asList(11, 13), observer.getActualSizes()); + } + } + } + + @Test + public void testUtf8Handling() throws Exception { + File tmpFile = tmpFolder.newFile(); + FileOutputStream output = new FileOutputStream(tmpFile); + // first line: €\n + // second line: ¢\n + output.write(new byte[]{(byte) 0xE2, (byte) 0x82, (byte) 0xAC, '\n', + (byte) 0xC2, (byte) 0xA2, '\n'}); + output.close(); + + { + // 3L is after the first line if counting codepoints, but within + // the first line if counting chars. So correct behavior is to return + // just one line, since offsets are in chars, not codepoints. + TextSource textSource = new TextSource<>( + tmpFile.getPath(), true, 0L, 3L, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertArrayEquals("€".getBytes("UTF-8"), + iterator.next().getBytes("UTF-8")); + Assert.assertFalse(iterator.hasNext()); + Assert.assertEquals(Arrays.asList(4), observer.getActualSizes()); + } + } + + { + // Starting location is mid-way into a codepoint. + // Ensures we don't fail when skipping over an incomplete codepoint. + TextSource textSource = new TextSource<>( + tmpFile.getPath(), true, 2L, null, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (Source.SourceIterator iterator = textSource.iterator()) { + Assert.assertArrayEquals("¢".getBytes("UTF-8"), + iterator.next().getBytes("UTF-8")); + Assert.assertFalse(iterator.hasNext()); + // The first '3' in the array represents the reading of a portion of the first + // line, which had to be read to find the beginning of the second line. + Assert.assertEquals(Arrays.asList(3, 3), observer.getActualSizes()); + } + } + } + + private void testNewlineHandling(String separator, boolean stripNewlines) + throws Exception { + File tmpFile = tmpFolder.newFile(); + PrintStream writer = + new PrintStream( + new FileOutputStream(tmpFile)); + List expected = Arrays.asList( + "", + " hi there ", + "bob", + "", + " ", + "--zowie!--", + ""); + List expectedSizes = new ArrayList<>(); + for (String line : expected) { + writer.print(line); + writer.print(separator); + expectedSizes.add(line.length() + separator.length()); + } + writer.close(); + + TextSource textSource = new TextSource<>( + tmpFile.getPath(), stripNewlines, null, null, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + List actual = new ArrayList<>(); + try (Source.SourceIterator iterator = textSource.iterator()) { + while (iterator.hasNext()) { + actual.add(iterator.next()); + } + } + + if (stripNewlines) { + Assert.assertEquals(expected, actual); + } else { + List unstripped = new LinkedList<>(); + for (String s : expected) { + unstripped.add(s + separator); + } + Assert.assertEquals(unstripped, actual); + } + + Assert.assertEquals(expectedSizes, observer.getActualSizes()); + } + + private void testStringPayload( + String[] lines, String separator, boolean stripNewlines) + throws Exception { + File tmpFile = tmpFolder.newFile(); + List expected = new ArrayList<>(); + PrintStream writer = + new PrintStream( + new FileOutputStream(tmpFile)); + for (String line : lines) { + writer.print(line); + writer.print(separator); + expected.add(stripNewlines ? line : line + separator); + } + writer.close(); + + TextSource textSource = new TextSource<>( + tmpFile.getPath(), stripNewlines, null, null, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + List actual = new ArrayList<>(); + try (Source.SourceIterator iterator = textSource.iterator()) { + while (iterator.hasNext()) { + actual.add(iterator.next()); + } + } + Assert.assertEquals(expected, actual); + } + + @Test + public void testCloneIteratorWithEndPositionAndFinalBytesInBuffer() + throws Exception { + String line = "a\n"; + boolean stripNewlines = false; + File tmpFile = tmpFolder.newFile(); + List expected = new ArrayList<>(); + PrintStream writer = new PrintStream(new FileOutputStream(tmpFile)); + // Write 5x the size of the buffer and 10 extra trailing bytes + for (long bytesWritten = 0; + bytesWritten < TextSource.BUF_SIZE * 3 + 10; ) { + writer.print(line); + expected.add(line); + bytesWritten += line.length(); + } + writer.close(); + Long fileSize = tmpFile.length(); + + TextSource textSource = new TextSource<>( + tmpFile.getPath(), stripNewlines, + null, fileSize, StringUtf8Coder.of()); + + List actual = new ArrayList<>(); + Source.SourceIterator iterator = textSource.iterator(); + while (iterator.hasNext()) { + actual.add(iterator.next()); + iterator = iterator.copy(); + } + Assert.assertEquals(expected, actual); + } + + @Test + public void testNonStringCoders() throws Exception { + File tmpFile = tmpFolder.newFile(); + PrintStream writer = + new PrintStream( + new FileOutputStream(tmpFile)); + List expected = TestUtils.INTS; + List expectedSizes = new ArrayList<>(); + for (Integer elem : expected) { + byte[] encodedElem = + CoderUtils.encodeToByteArray(TextualIntegerCoder.of(), elem); + writer.print(elem); + writer.print("\n"); + expectedSizes.add(1 + encodedElem.length); + } + writer.close(); + + TextSource textSource = new TextSource<>( + tmpFile.getPath(), true, null, null, TextualIntegerCoder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + List actual = new ArrayList<>(); + try (Source.SourceIterator iterator = textSource.iterator()) { + while (iterator.hasNext()) { + actual.add(iterator.next()); + } + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(expectedSizes, observer.getActualSizes()); + } + + @Test + public void testGetApproximatePosition() throws Exception { + File tmpFile = initTestFile(); + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, 0L, null, StringUtf8Coder.of()); + + try (Source.SourceIterator iterator = textSource.iterator()) { + ApproximateProgress progress = + sourceProgressToCloudProgress(iterator.getProgress()); + Assert.assertEquals(0L, + progress.getPosition().getByteOffset().longValue()); + iterator.next(); + progress = sourceProgressToCloudProgress(iterator.getProgress()); + Assert.assertEquals(11L, + progress.getPosition().getByteOffset().longValue()); + iterator.next(); + progress = sourceProgressToCloudProgress(iterator.getProgress()); + Assert.assertEquals(24L, + progress.getPosition().getByteOffset().longValue()); + iterator.next(); + progress = sourceProgressToCloudProgress(iterator.getProgress()); + Assert.assertEquals(34L, + progress.getPosition().getByteOffset().longValue()); + Assert.assertFalse(iterator.hasNext()); + } + } + + @Test + public void testUpdateStopPosition() throws Exception { + final long end = 10L; // in the first line + final long stop = 14L; // in the middle of the second line + File tmpFile = initTestFile(); + + com.google.api.services.dataflow.model.Position proposedStopPosition = + new com.google.api.services.dataflow.model.Position(); + + // Illegal proposed stop position, no update. + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, null, null, + StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + Assert.assertNull(iterator.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + } + } + + proposedStopPosition.setByteOffset(stop); + + // Successful update. + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, null, null, + StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + Assert.assertNull(iterator.getEndOffset()); + Assert.assertEquals( + stop, + sourcePositionToCloudPosition( + iterator.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))) + .getByteOffset().longValue()); + Assert.assertEquals(stop, iterator.getEndOffset().longValue()); + Assert.assertEquals(fileContent[0], iterator.next()); + Assert.assertEquals(fileContent[1], iterator.next()); + Assert.assertFalse(iterator.hasNext()); + Assert.assertEquals(Arrays.asList(fileContent[0].length(), + fileContent[1].length()), + observer.getActualSizes()); + } + } + + // Proposed stop position is before the current position, no update. + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, null, null, + StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + Assert.assertEquals(fileContent[0], iterator.next()); + Assert.assertEquals(fileContent[1], iterator.next()); + Assert.assertThat(sourceProgressToCloudProgress(iterator.getProgress()) + .getPosition().getByteOffset(), + greaterThan(stop)); + Assert.assertNull(iterator.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + Assert.assertNull(iterator.getEndOffset()); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(fileContent[2], iterator.next()); + Assert.assertEquals(Arrays.asList(fileContent[0].length(), + fileContent[1].length(), + fileContent[2].length()), + observer.getActualSizes()); + } + } + + // Proposed stop position is after the current stop (end) position, no update. + { + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, null, end, StringUtf8Coder.of()); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(textSource); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + Assert.assertEquals(fileContent[0], iterator.next()); + Assert.assertNull(iterator.updateStopPosition( + cloudProgressToSourceProgress(createApproximateProgress(proposedStopPosition)))); + Assert.assertEquals(end, iterator.getEndOffset().longValue()); + Assert.assertFalse(iterator.hasNext()); + Assert.assertEquals(Arrays.asList(fileContent[0].length()), + observer.getActualSizes()); + } + } + } + + @Test + public void testUpdateStopPositionExhaustive() throws Exception { + File tmpFile = initTestFile(); + + // Checks for every possible position in the file, that either we fail to + // "updateStop" at it, or we succeed and then reading both halves together + // yields the original file with no missed records or duplicates. + for (long start = 0; start < TOTAL_BYTES_COUNT - 1; start++) { + for (long end = start + 1; end < TOTAL_BYTES_COUNT; end++) { + for (long stop = start; stop <= end; stop++) { + stopPositionTestInternal(start, end, + stop, tmpFile); + } + } + } + + // Test with null start/end positions. + for (long stop = 0L; stop < TOTAL_BYTES_COUNT; stop++) { + stopPositionTestInternal(null, null, stop, tmpFile); + } + } + + private void stopPositionTestInternal(Long startOffset, + Long endOffset, + Long stopOffset, + File tmpFile) throws Exception { + String readWithoutSplit; + String readWithSplit1, readWithSplit2; + StringBuilder accumulatedRead = new StringBuilder(); + + // Read from source without split attempts. + TextSource textSource = new TextSource<>( + tmpFile.getPath(), false, startOffset, endOffset, + StringUtf8Coder.of()); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + while (iterator.hasNext()) { + accumulatedRead.append((String) iterator.next()); + } + readWithoutSplit = accumulatedRead.toString(); + } + + // Read the first half of the split. + textSource = new TextSource<>( + tmpFile.getPath(), false, startOffset, stopOffset, + StringUtf8Coder.of()); + accumulatedRead = new StringBuilder(); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + while (iterator.hasNext()) { + accumulatedRead.append((String) iterator.next()); + } + readWithSplit1 = accumulatedRead.toString(); + } + + // Read the second half of the split. + textSource = new TextSource<>( + tmpFile.getPath(), false, stopOffset, endOffset, + StringUtf8Coder.of()); + accumulatedRead = new StringBuilder(); + + try (TextFileIterator iterator = (TextFileIterator) textSource.iterator()) { + while (iterator.hasNext()) { + accumulatedRead.append((String) iterator.next()); + } + readWithSplit2 = accumulatedRead.toString(); + } + + Assert.assertEquals(readWithoutSplit, readWithSplit1 + readWithSplit2); + } + + private ApproximateProgress createApproximateProgress( + com.google.api.services.dataflow.model.Position position) { + return new ApproximateProgress().setPosition(position); + } + + // TODO: sharded filenames + // TODO: reading from GCS +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSourceTest.java new file mode 100644 index 000000000000..3a360d8d24ad --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/UngroupedShuffleSourceTest.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.ShuffleEntry; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Tests for UngroupedShuffleSource. + */ +@RunWith(JUnit4.class) +public class UngroupedShuffleSourceTest { + static final Instant timestamp = new Instant(123000); + static final IntervalWindow window = new IntervalWindow(timestamp, timestamp.plus(1000)); + + byte[] asShuffleKey(long seqNum) throws Exception { + return CoderUtils.encodeToByteArray(BigEndianLongCoder.of(), seqNum); + } + + byte[] asShuffleValue(Integer value) throws Exception { + return CoderUtils.encodeToByteArray( + WindowedValue.getFullCoder(BigEndianIntegerCoder.of(), IntervalWindow.getCoder()), + WindowedValue.of(value, timestamp, Lists.newArrayList(window))); + } + + void runTestReadShuffleSource(List expected) throws Exception { + UngroupedShuffleSource> shuffleSource = + new UngroupedShuffleSource<>( + PipelineOptionsFactory.create(), + null, null, null, + WindowedValue.getFullCoder(BigEndianIntegerCoder.of(), IntervalWindow.getCoder())); + ExecutorTestUtils.TestSourceObserver observer = + new ExecutorTestUtils.TestSourceObserver(shuffleSource); + + TestShuffleReader shuffleReader = new TestShuffleReader(); + List expectedSizes = new ArrayList<>(); + long seqNum = 0; + for (Integer value : expected) { + byte[] shuffleKey = asShuffleKey(seqNum++); + byte[] shuffleValue = asShuffleValue(value); + shuffleReader.addEntry(shuffleKey, shuffleValue); + + ShuffleEntry record = new ShuffleEntry(shuffleKey, null, shuffleValue); + expectedSizes.add(record.length()); + } + + List actual = new ArrayList<>(); + try (Source.SourceIterator> iter = + shuffleSource.iterator(shuffleReader)) { + while (iter.hasNext()) { + Assert.assertTrue(iter.hasNext()); + Assert.assertTrue(iter.hasNext()); + WindowedValue elem = iter.next(); + actual.add(elem.getValue()); + } + Assert.assertFalse(iter.hasNext()); + Assert.assertFalse(iter.hasNext()); + try { + iter.next(); + Assert.fail("should have failed"); + } catch (NoSuchElementException exn) { + // As expected. + } + } + + Assert.assertEquals(expected, actual); + Assert.assertEquals(expectedSizes, observer.getActualSizes()); + } + + @Test + public void testReadEmptyShuffleSource() throws Exception { + runTestReadShuffleSource(TestUtils.NO_INTS); + } + + @Test + public void testReadNonEmptyShuffleSource() throws Exception { + runTestReadShuffleSource(TestUtils.INTS); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingFormatterTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingFormatterTest.java new file mode 100644 index 000000000000..065092aeaab1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingFormatterTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker.logging; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.testing.RestoreMappedDiagnosticContext; +import com.google.common.collect.ImmutableMap; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.MDC; + +import java.util.logging.Level; +import java.util.logging.LogRecord; + +/** Unit tests for {@link DataflowWorkerLoggingFormatter}. */ +@RunWith(JUnit4.class) +public class DataflowWorkerLoggingFormatterTest { + @Rule public TestRule restoreMDC = new RestoreMappedDiagnosticContext(); + + @Test + public void testWithUnsetValuesInMDC() { + assertEquals( + "1970-01-01T00:00:00.001Z INFO unknown unknown unknown 2 LoggerName " + + "test.message\n", + new DataflowWorkerLoggingFormatter().format( + createLogRecord("test.message", null))); + } + + @Test + public void testWithMessage() { + MDC.setContextMap(ImmutableMap.of( + "dataflow.jobId", "testJobId", + "dataflow.workerId", "testWorkerId", + "dataflow.workId", "testWorkId")); + assertEquals( + "1970-01-01T00:00:00.001Z INFO testJobId testWorkerId testWorkId 2 LoggerName " + + "test.message\n", + new DataflowWorkerLoggingFormatter().format( + createLogRecord("test.message", null))); + } + + @Test + public void testWithMessageAndException() { + MDC.setContextMap(ImmutableMap.of( + "dataflow.jobId", "testJobId", + "dataflow.workerId", "testWorkerId", + "dataflow.workId", "testWorkId")); + assertEquals( + "1970-01-01T00:00:00.001Z INFO testJobId testWorkerId testWorkId 2 LoggerName " + + "test.message\n" + + "java.lang.Throwable: exception.test.message\n" + + "\tat declaringClass1.method1(file1.java:1)\n" + + "\tat declaringClass2.method2(file2.java:1)\n" + + "\tat declaringClass3.method3(file3.java:1)\n", + new DataflowWorkerLoggingFormatter().format( + createLogRecord("test.message", createThrowable()))); + } + + @Test + public void testWithException() { + MDC.setContextMap(ImmutableMap.of( + "dataflow.jobId", "testJobId", + "dataflow.workerId", "testWorkerId", + "dataflow.workId", "testWorkId")); + assertEquals( + "1970-01-01T00:00:00.001Z INFO testJobId testWorkerId testWorkId 2 LoggerName null\n" + + "java.lang.Throwable: exception.test.message\n" + + "\tat declaringClass1.method1(file1.java:1)\n" + + "\tat declaringClass2.method2(file2.java:1)\n" + + "\tat declaringClass3.method3(file3.java:1)\n", + new DataflowWorkerLoggingFormatter().format( + createLogRecord(null, createThrowable()))); + } + + @Test + public void testWithoutExceptionOrMessage() { + MDC.setContextMap(ImmutableMap.of( + "dataflow.jobId", "testJobId", + "dataflow.workerId", "testWorkerId", + "dataflow.workId", "testWorkId")); + assertEquals( + "1970-01-01T00:00:00.001Z INFO testJobId testWorkerId testWorkId 2 LoggerName null\n", + new DataflowWorkerLoggingFormatter().format( + createLogRecord(null, null))); + } + + /** + * @return A throwable with a fixed stack trace. + */ + private Throwable createThrowable() { + Throwable throwable = new Throwable("exception.test.message"); + throwable.setStackTrace(new StackTraceElement[]{ + new StackTraceElement("declaringClass1", "method1", "file1.java", 1), + new StackTraceElement("declaringClass2", "method2", "file2.java", 1), + new StackTraceElement("declaringClass3", "method3", "file3.java", 1), + }); + return throwable; + } + + /** + * Creates and returns a LogRecord with a given message and throwable. + * + * @param message The message to place in the {@link LogRecord} + * @param throwable The throwable to place in the {@link LogRecord} + * @return A {@link LogRecord} with the given message and throwable. + */ + private LogRecord createLogRecord(String message, Throwable throwable) { + LogRecord logRecord = new LogRecord(Level.INFO, message); + logRecord.setLoggerName("LoggerName"); + logRecord.setMillis(1L); + logRecord.setThreadID(2); + logRecord.setThrown(throwable); + return logRecord; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingInitializerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingInitializerTest.java new file mode 100644 index 000000000000..71e51f430d88 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/worker/logging/DataflowWorkerLoggingInitializerTest.java @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.worker.logging; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.List; +import java.util.logging.ConsoleHandler; +import java.util.logging.FileHandler; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogManager; +import java.util.logging.Logger; + +/** Unit tests for {@link DataflowWorkerLoggingInitializer}. */ +@RunWith(JUnit4.class) +public class DataflowWorkerLoggingInitializerTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + + @Mock LogManager mockLogManager; + @Mock Logger mockRootLogger; + @Mock Handler mockHandler; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + when(mockLogManager.getLogger("")).thenReturn(mockRootLogger); + when(mockRootLogger.getHandlers()).thenReturn(new Handler[]{ mockHandler }); + } + + @Test + public void testWithDefaults() { + ArgumentCaptor argument = ArgumentCaptor.forClass(Handler.class); + + new DataflowWorkerLoggingInitializer().initialize(mockLogManager); + verify(mockLogManager).getLogger(""); + verify(mockLogManager).reset(); + verify(mockRootLogger).getHandlers(); + verify(mockRootLogger).removeHandler(mockHandler); + verify(mockRootLogger).setLevel(Level.INFO); + verify(mockRootLogger, times(2)).addHandler(argument.capture()); + verifyNoMoreInteractions(mockLogManager, mockRootLogger); + + List handlers = argument.getAllValues(); + assertTrue(isConsoleHandler(handlers.get(0), Level.INFO)); + assertTrue(isFileHandler(handlers.get(1), Level.INFO)); + } + + @Test + public void testWithOverrides() { + ArgumentCaptor argument = ArgumentCaptor.forClass(Handler.class); + System.setProperty("dataflow.worker.logging.level", "WARNING"); + + new DataflowWorkerLoggingInitializer().initialize(mockLogManager); + verify(mockLogManager).getLogger(""); + verify(mockLogManager).reset(); + verify(mockRootLogger).getHandlers(); + verify(mockRootLogger).removeHandler(mockHandler); + verify(mockRootLogger).setLevel(Level.WARNING); + verify(mockRootLogger, times(2)).addHandler(argument.capture()); + verifyNoMoreInteractions(mockLogManager, mockRootLogger); + + List handlers = argument.getAllValues(); + assertTrue(isConsoleHandler(handlers.get(0), Level.WARNING)); + assertTrue(isFileHandler(handlers.get(1), Level.WARNING)); + } + + private boolean isConsoleHandler(Handler handler, Level level) { + return handler instanceof ConsoleHandler + && level.equals(handler.getLevel()) + && handler.getFormatter() instanceof DataflowWorkerLoggingFormatter; + } + + private boolean isFileHandler(Handler handler, Level level) { + return handler instanceof FileHandler + && level.equals(handler.getLevel()) + && handler.getFormatter() instanceof DataflowWorkerLoggingFormatter; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogs.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogs.java new file mode 100644 index 000000000000..3f4e33d63268 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogs.java @@ -0,0 +1,240 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.fail; + +import com.google.common.collect.Lists; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +import java.util.Collection; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; + +/** + * This {@link TestRule} enables the ability to capture JUL logging events during test execution and + * assert expectations that they contain certain messages (with or without {@link Throwable}) at + * certain log levels. For logs generated via the SLF4J logging frontend, the JUL backend must be + * used. + */ +public class ExpectedLogs extends ExternalResource { + /** + * Returns a {@link TestRule} which captures logs for the given class. + * + * @param klass The class to capture logs for. + * @return A {@link ExpectedLogs} test rule. + */ + public static ExpectedLogs none(Class klass) { + return new ExpectedLogs(klass); + } + + /** + * Expect a logging event at the trace level with the given message. + * + * @param substring The message to match against. + */ + public void expectTrace(String substring) { + expect(Level.FINEST, substring); + } + + /** + * Expect a logging event at the trace level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. + */ + public void expectTrace(String substring, Throwable t) { + expect(Level.FINEST, substring, t); + } + + /** + * Expect a logging event at the debug level with the given message. + * + * @param substring The message to match against. + */ + public void expectDebug(String substring) { + expect(Level.FINE, substring); + } + + /** + * Expect a logging event at the debug level with the given message and throwable. + * + * @param message The message to match against. + * @param t The throwable to match against. + */ + public void expectDebug(String message, Throwable t) { + expect(Level.FINE, message, t); + } + + /** + * Expect a logging event at the info level with the given message. + * @param substring The message to match against. + */ + public void expectInfo(String substring) { + expect(Level.INFO, substring); + } + + /** + * Expect a logging event at the info level with the given message and throwable. + * + * @param message The message to match against. + * @param t The throwable to match against. + */ + public void expectInfo(String message, Throwable t) { + expect(Level.INFO, message, t); + } + + /** + * Expect a logging event at the warn level with the given message. + * + * @param substring The message to match against. + */ + public void expectWarn(String substring) { + expect(Level.WARNING, substring); + } + + /** + * Expect a logging event at the warn level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. + */ + public void expectWarn(String substring, Throwable t) { + expect(Level.WARNING, substring, t); + } + + /** + * Expect a logging event at the error level with the given message. + * + * @param substring The message to match against. + */ + public void expectError(String substring) { + expect(Level.SEVERE, substring); + } + + /** + * Expect a logging event at the error level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. + */ + public void expectError(String substring, Throwable t) { + expect(Level.SEVERE, substring, t); + } + + private void expect(final Level level, final String substring) { + expectations.add(new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText(String.format( + "Expected log message of level [%s] containing message [%s]", level, substring)); + } + + @Override + protected boolean matchesSafely(LogRecord item) { + return level.equals(item.getLevel()) + && item.getMessage().contains(substring); + } + }); + } + + private void expect(final Level level, final String substring, final Throwable throwable) { + expectations.add(new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText(String.format( + "Expected log message of level [%s] containg message [%s] with exception [%s] " + + "containing message [%s]", + level, substring, throwable.getClass(), throwable.getMessage())); + } + + @Override + protected boolean matchesSafely(LogRecord item) { + return level.equals(item.getLevel()) + && item.getMessage().contains(substring) + && item.getThrown().getClass().equals(throwable.getClass()) + && item.getThrown().getMessage().contains(throwable.getMessage()); + } + }); + } + + @Override + protected void before() throws Throwable { + previousLevel = log.getLevel(); + log.setLevel(Level.ALL); + log.addHandler(logSaver); + } + + @Override + protected void after() { + log.removeHandler(logSaver); + log.setLevel(previousLevel); + Collection> missingExpecations = Lists.newArrayList(); + FOUND: for (Matcher expectation : expectations) { + for (LogRecord log : logSaver.getLogs()) { + if (expectation.matches(log)) { + continue FOUND; + } + } + missingExpecations.add(expectation); + } + + if (!missingExpecations.isEmpty()) { + fail(String.format("Missed logging expectations: %s", missingExpecations)); + } + } + + private final Logger log; + private final LogSaver logSaver; + private final Collection> expectations; + private Level previousLevel; + + private ExpectedLogs(Class klass) { + log = Logger.getLogger(klass.getName()); + logSaver = new LogSaver(); + expectations = Lists.newArrayList(); + } + + /** + * A JUL logging {@link Handler} that records all logging events which are passed to it. + */ + private static class LogSaver extends Handler { + Collection logRecords = Lists.newArrayList(); + + public Collection getLogs() { + return logRecords; + } + + @Override + public void publish(LogRecord record) { + logRecords.add(record); + } + + @Override + public void flush() {} + + @Override + public void close() throws SecurityException {} + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogsTest.java new file mode 100644 index 000000000000..4d9cd0e76639 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogsTest.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Random; + +/** Tests for {@link FastNanoClockAndSleeper}. */ +@RunWith(JUnit4.class) +public class ExpectedLogsTest { + private static final Logger LOG = LoggerFactory.getLogger(ExpectedLogsTest.class); + + private ExpectedLogs expectedLogs; + + @Before + public void setUp() { + expectedLogs = ExpectedLogs.none(ExpectedLogsTest.class); + } + + @Test + public void testWhenNoExpectations() throws Throwable { + expectedLogs.before(); + LOG.error(generateRandomString()); + expectedLogs.after(); + } + + @Test + public void testWhenExpectationIsMatchedFully() throws Throwable { + String expected = generateRandomString(); + expectedLogs.before(); + expectedLogs.expectError(expected); + LOG.error(expected); + expectedLogs.after(); + } + + + @Test + public void testWhenExpectationIsMatchedPartially() throws Throwable { + String expected = generateRandomString(); + expectedLogs.before(); + expectedLogs.expectError(expected); + LOG.error("Extra stuff around expected " + expected + " blah"); + expectedLogs.after(); + } + + @Test + public void testWhenExpectationIsMatchedWithExceptionBeingLogged() throws Throwable { + String expected = generateRandomString(); + expectedLogs.before(); + expectedLogs.expectError(expected); + LOG.error(expected, new IOException()); + expectedLogs.after(); + } + + @Test(expected = AssertionError.class) + public void testWhenExpectationIsNotMatched() throws Throwable { + String expected = generateRandomString(); + expectedLogs.before(); + expectedLogs.expectError(expected); + expectedLogs.after(); + } + + @Test + public void testLogCaptureOccursAtLowestLogLevel() throws Throwable { + String expected = generateRandomString(); + expectedLogs.before(); + expectedLogs.expectTrace(expected); + LOG.trace(expected); + expectedLogs.after(); + } + + // Generates random strings of 10 characters. + private static String generateRandomString() { + Random random = new Random(); + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < 10; i++) { + builder.append('a' + (char) random.nextInt(26)); + } + return builder.toString(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeper.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeper.java new file mode 100644 index 000000000000..e9fa9839e737 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeper.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; + +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +/** + * This object quickly moves time forward based upon how much it has been asked to sleep, + * without actually sleeping, to simulate the backoff. + */ +public class FastNanoClockAndSleeper extends ExternalResource + implements NanoClock, Sleeper, TestRule { + private long fastNanoTime; + + @Override + public long nanoTime() { + return fastNanoTime; + } + + @Override + protected void before() throws Throwable { + fastNanoTime = NanoClock.SYSTEM.nanoTime(); + } + + @Override + public void sleep(long millis) throws InterruptedException { + fastNanoTime += millis * 1000000L; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeperTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeperTest.java new file mode 100644 index 000000000000..3c9275f54a23 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeperTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.concurrent.TimeUnit; + +/** Tests for {@link FastNanoClockAndSleeper}. */ +@RunWith(JUnit4.class) +public class FastNanoClockAndSleeperTest { + @Rule public FastNanoClockAndSleeper fastNanoClockAndSleeper = new FastNanoClockAndSleeper(); + + @Test + public void testClockAndSleeper() throws Exception { + long sleepTimeMs = TimeUnit.SECONDS.toMillis(30); + long sleepTimeNano = TimeUnit.MILLISECONDS.toNanos(sleepTimeMs); + long fakeTimeNano = fastNanoClockAndSleeper.nanoTime(); + long startTimeNano = System.nanoTime(); + fastNanoClockAndSleeper.sleep(sleepTimeMs); + long maxTimeNano = startTimeNano + TimeUnit.SECONDS.toNanos(1); + // Verify that actual time didn't progress as much as was requested + assertTrue(System.nanoTime() < maxTimeNano); + // Verify that the fake time did go up by the amount requested + assertEquals(fakeTimeNano + sleepTimeNano, fastNanoClockAndSleeper.nanoTime()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProvider.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProvider.java new file mode 100644 index 000000000000..675d7ac11361 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProvider.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import org.joda.time.DateTimeUtils; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +/** + * This {@link TestRule} resets the date time provider in Joda to the system date + * time provider after tests. + */ +public class ResetDateTimeProvider extends ExternalResource { + public void setDateTimeFixed(String iso8601) { + setDateTimeFixed(ISODateTimeFormat.dateTime().parseMillis(iso8601)); + } + + public void setDateTimeFixed(long millis) { + DateTimeUtils.setCurrentMillisFixed(millis); + } + + @Override + protected void after() { + DateTimeUtils.setCurrentMillisSystem(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProviderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProviderTest.java new file mode 100644 index 000000000000..5aa96835676c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProviderTest.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import org.joda.time.DateTimeUtils; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ResetDateTimeProvider}. */ +@RunWith(JUnit4.class) +public class ResetDateTimeProviderTest { + private static final String TEST_TIME = "2014-12-08T19:07:06.698Z"; + private static final long TEST_TIME_MS = + ISODateTimeFormat.dateTime().parseMillis(TEST_TIME); + + @Rule public ResetDateTimeProvider resetDateTimeProviderRule = new ResetDateTimeProvider(); + + /* + * Since these tests can run out of order, both test A and B change the provider + * and verify that the provider was reset. + */ + @Test + public void testResetA() { + assertNotEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + resetDateTimeProviderRule.setDateTimeFixed(TEST_TIME); + assertEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + } + + @Test + public void testResetB() { + assertNotEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + resetDateTimeProviderRule.setDateTimeFixed(TEST_TIME); + assertEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreMappedDiagnosticContext.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreMappedDiagnosticContext.java new file mode 100644 index 000000000000..f0bdb9e21704 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreMappedDiagnosticContext.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; + +import org.junit.rules.ExternalResource; +import org.slf4j.MDC; + +import java.util.Map; + +/** + * Saves and restores the current MDC for tests. + */ +public class RestoreMappedDiagnosticContext extends ExternalResource { + private Map previousValue; + + public RestoreMappedDiagnosticContext() { + } + + @Override + protected void before() throws Throwable { + previousValue = MoreObjects.firstNonNull( + MDC.getCopyOfContextMap(), + ImmutableMap.of()); + } + + @Override + protected void after() { + MDC.setContextMap(previousValue); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreMappedDiagnosticContextTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreMappedDiagnosticContextTest.java new file mode 100644 index 000000000000..c88f275f4bf6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreMappedDiagnosticContextTest.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.MDC; + +/** Tests for {@link RestoreMappedDiagnosticContext}. */ +@RunWith(JUnit4.class) +public class RestoreMappedDiagnosticContextTest { + @Rule public TestRule restoreMappedDiagnosticContext = new RestoreMappedDiagnosticContext(); + + /* + * Since these tests can run out of order, both test A and B verify that they + * could insert their property and that the other does not exist. + */ + @Test + public void testThatMDCIsClearedA() { + MDC.put("TestA", "TestA"); + assertNotNull(MDC.get("TestA")); + assertNull(MDC.get("TestB")); + } + + @Test + public void testThatMDCIsClearedB() { + MDC.put("TestB", "TestB"); + assertNotNull(MDC.get("TestB")); + assertNull(MDC.get("TestA")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemProperties.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemProperties.java new file mode 100644 index 000000000000..ef4f3427b889 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemProperties.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.common.base.Throwables; + +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** + * Saves and restores the current system properties for tests. + */ +public class RestoreSystemProperties extends ExternalResource implements TestRule { + private byte[] originalProperties; + + @Override + protected void before() throws Throwable { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + System.getProperties().store(baos, ""); + baos.close(); + originalProperties = baos.toByteArray(); + } + + @Override + protected void after() { + try (ByteArrayInputStream bais = new ByteArrayInputStream(originalProperties)) { + System.getProperties().clear(); + System.getProperties().load(bais); + } catch (IOException e) { + throw Throwables.propagate(e); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemPropertiesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemPropertiesTest.java new file mode 100644 index 000000000000..8a4bb488922e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemPropertiesTest.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RestoreSystemProperties}. */ +@RunWith(JUnit4.class) +public class RestoreSystemPropertiesTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + + /* + * Since these tests can run out of order, both test A and B verify that they + * could insert their property and that the other does not exist. + */ + @Test + public void testThatPropertyIsClearedA() { + System.getProperties().put("TestA", "TestA"); + assertNotNull(System.getProperty("TestA")); + assertNull(System.getProperty("TestB")); + } + + @Test + public void testThatPropertyIsClearedB() { + System.getProperties().put("TestB", "TestB"); + assertNotNull(System.getProperty("TestB")); + assertNull(System.getProperty("TestA")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestPipelineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestPipelineTest.java new file mode 100644 index 000000000000..da4f66ec0775 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestPipelineTest.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.common.collect.ImmutableMap; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TestPipeline}. */ +@RunWith(JUnit4.class) +public class TestPipelineTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + + @Test + public void testCreationUsingDefaults() { + assertNotNull(TestPipeline.create()); + } + + @Test + public void testCreationOfPipelineOptions() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + String stringOptions = mapper.writeValueAsString( + ImmutableMap.of("options", + ImmutableMap.builder() + .put("runner", DataflowPipelineRunner.class.getName()) + .put("project", "testProject") + .put("apiRootUrl", "testApiRootUrl") + .put("dataflowEndpoint", "testDataflowEndpoint") + .put("tempLocation", "testTempLocation") + .put("serviceAccountName", "testServiceAccountName") + .put("serviceAccountKeyfile", "testServiceAccountKeyfile") + .put("zone", "testZone") + .put("numWorkers", "1") + .put("diskSizeGb", "2") + .put("shuffleDiskSizeGb", "3") + .build())); + System.getProperties().put("dataflowOptions", stringOptions); + TestDataflowPipelineOptions options = TestPipeline.getPipelineOptions(); + assertEquals(DataflowPipelineRunner.class, options.getRunner()); + assertEquals("TestPipelineTest", options.getAppName()); + assertEquals("testCreationOfPipelineOptions", options.getJobName()); + assertEquals("testProject", options.getProject()); + assertEquals("testApiRootUrl", options.getApiRootUrl()); + assertEquals("testDataflowEndpoint", options.getDataflowEndpoint()); + assertEquals("testTempLocation", options.getTempLocation()); + assertEquals("testServiceAccountName", options.getServiceAccountName()); + assertEquals("testServiceAccountKeyfile", options.getServiceAccountKeyfile()); + assertEquals("testZone", options.getZone()); + assertEquals(2, options.getDiskSizeGb()); + assertEquals(3, options.getShuffleDiskSizeGb()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantilesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantilesTest.java new file mode 100644 index 000000000000..b0493491634c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantilesTest.java @@ -0,0 +1,287 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static com.google.cloud.dataflow.sdk.TestUtils.createInts; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.ApproximateQuantiles.ApproximateQuantilesCombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for ApproximateQuantiles + */ +@RunWith(JUnit4.class) +public class ApproximateQuantilesTest { + + static final List> TABLE = Arrays.asList( + KV.of("a", 1), + KV.of("a", 2), + KV.of("a", 3), + KV.of("b", 1), + KV.of("b", 10), + KV.of("b", 10), + KV.of("b", 100) + ); + + public PCollection> createInputTable(Pipeline p) { + return p.apply(Create.of(TABLE)).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + } + + @Test + public void testQuantilesGlobally() { + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = intRangeCollection(p, 101); + PCollection> quantiles = + input.apply(ApproximateQuantiles.globally(5)); + + DirectPipelineRunner.EvaluationResults results = p.run(); + + DataflowAssert.that(quantiles) + .containsInAnyOrder(Arrays.asList(0, 25, 50, 75, 100)); + } + + @Test + public void testQuantilesGobally_comparable() { + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = intRangeCollection(p, 101); + PCollection> quantiles = + input.apply( + ApproximateQuantiles.globally(5, new DescendingIntComparator())); + + DirectPipelineRunner.EvaluationResults results = p.run(); + + DataflowAssert.that(quantiles) + .containsInAnyOrder(Arrays.asList(100, 75, 50, 25, 0)); + } + + @Test + public void testQuantilesPerKey() { + Pipeline p = TestPipeline.create(); + + PCollection> input = createInputTable(p); + PCollection>> quantiles = input.apply( + ApproximateQuantiles.perKey(2)); + + DataflowAssert.that(quantiles) + .containsInAnyOrder( + KV.of("a", Arrays.asList(1, 3)), + KV.of("b", Arrays.asList(1, 100))); + p.run(); + + } + + @Test + public void testQuantilesPerKey_reversed() { + Pipeline p = TestPipeline.create(); + + PCollection> input = createInputTable(p); + PCollection>> quantiles = input.apply( + ApproximateQuantiles.perKey( + 2, new DescendingIntComparator())); + + DataflowAssert.that(quantiles) + .containsInAnyOrder( + KV.of("a", Arrays.asList(3, 1)), + KV.of("b", Arrays.asList(100, 1))); + p.run(); + } + + @Test + public void testSingleton() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + Arrays.asList(389), + Arrays.asList(389, 389, 389, 389, 389)); + } + + @Test + public void testSimpleQuantiles() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + intRange(101), + Arrays.asList(0, 25, 50, 75, 100)); + } + + @Test + public void testUnevenQuantiles() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(37), + intRange(5000), + quantileMatcher(5000, 37, 20 /* tolerance */)); + } + + @Test + public void testLargerQuantiles() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(50), + intRange(10001), + quantileMatcher(10001, 50, 20 /* tolerance */)); + } + + @Test + public void testTightEpsilon() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(10).withEpsilon(0.01), + intRange(10001), + quantileMatcher(10001, 10, 5 /* tolerance */)); + } + + @Test + public void testDuplicates() { + int size = 101; + List all = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + all.addAll(intRange(size)); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(0, 25, 50, 75, 100)); + } + + @Test + public void testLotsOfDuplicates() { + List all = new ArrayList<>(); + all.add(1); + for (int i = 1; i < 300; i++) { + all.add(2); + } + for (int i = 300; i < 1000; i++) { + all.add(3); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(1, 2, 3, 3, 3)); + } + + @Test + public void testLogDistribution() { + List all = new ArrayList<>(); + for (int i = 1; i < 1000; i++) { + all.add((int) Math.log(i)); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(0, 5, 6, 6, 6)); + } + + @Test + public void testZipfianDistribution() { + List all = new ArrayList<>(); + for (int i = 1; i < 1000; i++) { + all.add(1000 / i); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(1, 1, 2, 4, 1000)); + } + + @Test + public void testAlternateComparator() { + List inputs = Arrays.asList( + "aa", "aaa", "aaaa", "b", "ccccc", "dddd", "zz"); + checkCombineFn( + ApproximateQuantilesCombineFn.create(3), + inputs, + Arrays.asList("aa", "b", "zz")); + checkCombineFn( + ApproximateQuantilesCombineFn.create(3, new TopTest.OrderByLength()), + inputs, + Arrays.asList("b", "aaa", "ccccc")); + } + + private Matcher> quantileMatcher( + int size, int numQuantiles, int absoluteError) { + List> quantiles = new ArrayList<>(); + quantiles.add(CoreMatchers.is(0)); + for (int k = 1; k < numQuantiles - 1; k++) { + int expected = (int) (((double) (size - 1)) * k / (numQuantiles - 1)); + quantiles.add(new Between<>( + expected - absoluteError, expected + absoluteError)); + } + quantiles.add(CoreMatchers.is(size - 1)); + return contains(quantiles); + } + + private static class Between> + extends TypeSafeDiagnosingMatcher { + private final T min; + private final T max; + private Between(T min, T max) { + this.min = min; + this.max = max; + } + @Override + public void describeTo(Description description) { + description.appendText("is between " + min + " and " + max); + } + + @Override + protected boolean matchesSafely(T item, Description mismatchDescription) { + return min.compareTo(item) <= 0 && item.compareTo(max) <= 0; + } + } + + private static class DescendingIntComparator implements + SerializableComparator { + @Override + public int compare(Integer o1, Integer o2) { + return o2.compareTo(o1); + } + } + + private PCollection intRangeCollection(Pipeline p, int size) { + return createInts(p, intRange(size)); + } + + private List intRange(int size) { + List all = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + all.add(i); + } + return all; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUniqueTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUniqueTest.java new file mode 100644 index 000000000000..2b2ff0ac9c96 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUniqueTest.java @@ -0,0 +1,302 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.EvaluationResults; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for the ApproximateUnique aggregator transform. + */ +@RunWith(JUnit4.class) +public class ApproximateUniqueTest { + + @Test + public void testEstimationErrorToSampleSize() { + assertEquals(40000, ApproximateUnique.sampleSizeFromEstimationError(0.01)); + assertEquals(10000, ApproximateUnique.sampleSizeFromEstimationError(0.02)); + assertEquals(2500, ApproximateUnique.sampleSizeFromEstimationError(0.04)); + assertEquals(1600, ApproximateUnique.sampleSizeFromEstimationError(0.05)); + assertEquals(400, ApproximateUnique.sampleSizeFromEstimationError(0.1)); + assertEquals(100, ApproximateUnique.sampleSizeFromEstimationError(0.2)); + assertEquals(25, ApproximateUnique.sampleSizeFromEstimationError(0.4)); + assertEquals(16, ApproximateUnique.sampleSizeFromEstimationError(0.5)); + } + + public PCollection createInput(Pipeline p, Iterable input, + Coder coder) { + return p.apply(Create.of(input)).setCoder(coder); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testApproximateUniqueWithSmallInput() { + Pipeline p = TestPipeline.create(); + + PCollection input = + createInput(p, Arrays.asList(1, 2, 3, 3), BigEndianIntegerCoder.of()); + + PCollection estimate = input + .apply(ApproximateUnique.globally(1000)); + + DataflowAssert.that(estimate).containsInAnyOrder(3L); + + p.run(); + } + + @Test + public void testApproximateUniqueWithDuplicates() { + runApproximateUniqueWithDuplicates(100, 100, 100); + runApproximateUniqueWithDuplicates(1000, 1000, 100); + runApproximateUniqueWithDuplicates(1500, 1000, 100); + runApproximateUniqueWithDuplicates(10000, 1000, 100); + } + + private void runApproximateUniqueWithDuplicates(int elementCount, + int uniqueCount, int sampleSize) { + + assert elementCount >= uniqueCount; + List elements = Lists.newArrayList(); + for (int i = 0; i < elementCount; i++) { + elements.add(1.0 / (i % uniqueCount + 1)); + } + Collections.shuffle(elements); + + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = createInput(p, elements, DoubleCoder.of()); + PCollection estimate = + input.apply(ApproximateUnique.globally(sampleSize)); + + EvaluationResults results = p.run(); + + verifyEstimate(uniqueCount, sampleSize, + results.getPCollection(estimate).get(0)); + } + + @Test + public void testApproximateUniqueWithSkewedDistributions() { + runApproximateUniqueWithSkewedDistributions(100, 100, 100); + runApproximateUniqueWithSkewedDistributions(10000, 10000, 100); + runApproximateUniqueWithSkewedDistributions(10000, 1000, 100); + runApproximateUniqueWithSkewedDistributions(10000, 200, 100); + } + + @Test + public void testApproximateUniqueWithSkewedDistributionsAndLargeSampleSize() { + runApproximateUniqueWithSkewedDistributions(10000, 2000, 1000); + } + + private void runApproximateUniqueWithSkewedDistributions(int elementCount, + final int uniqueCount, final int sampleSize) { + List elements = Lists.newArrayList(); + // Zipf distribution with approximately elementCount items. + double s = 1 - 1.0 * uniqueCount / elementCount; + double maxCount = Math.pow(uniqueCount, s); + for (int k = 0; k < uniqueCount; k++) { + int count = Math.max(1, (int) Math.round(maxCount * Math.pow(k, -s))); + // Element k occurs count times. + for (int c = 0; c < count; c++) { + elements.add(k); + } + } + + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + createInput(p, elements, BigEndianIntegerCoder.of()); + PCollection estimate = + input.apply(ApproximateUnique.globally(sampleSize)); + + EvaluationResults results = p.run(); + + verifyEstimate(uniqueCount, sampleSize, + results.getPCollection(estimate).get(0).longValue()); + } + + @Test + public void testApproximateUniquePerKey() { + List> elements = Lists.newArrayList(); + List keys = ImmutableList.of(20, 50, 100); + int elementCount = 1000; + int sampleSize = 100; + // Use the key as the number of unique values. + for (int uniqueCount : keys) { + for (int value = 0; value < elementCount; value++) { + elements.add(KV.of(uniqueCount, value % uniqueCount)); + } + } + + DirectPipeline p = DirectPipeline.createForTest(); + PCollection> input = createInput(p, elements, + KvCoder.of(BigEndianIntegerCoder.of(), BigEndianIntegerCoder.of())); + PCollection> counts = + input.apply(ApproximateUnique.perKey(sampleSize)); + + EvaluationResults results = p.run(); + + for (KV result : results.getPCollection(counts)) { + verifyEstimate(result.getKey(), sampleSize, result.getValue()); + } + } + + /** + * Applies {@link ApproximateUnique} for different sample sizes and verifies + * that the estimation error falls within the maximum allowed error of + * {@code 2 / sqrt(sampleSize)}. + */ + @Test + public void testApproximateUniqueWithDifferentSampleSizes() { + runApproximateUniquePipeline(16); + runApproximateUniquePipeline(64); + runApproximateUniquePipeline(128); + runApproximateUniquePipeline(256); + runApproximateUniquePipeline(512); + runApproximateUniquePipeline(1000); + runApproximateUniquePipeline(1024); + try { + runApproximateUniquePipeline(15); + fail("Accepted sampleSize < 16"); + } catch (IllegalArgumentException e) { + assertTrue("Expected an exception due to sampleSize < 16", e.getMessage() + .startsWith("ApproximateUnique needs a sampleSize >= 16")); + } + } + + /** + * Applies {@code ApproximateUnique(sampleSize)} verifying that the estimation + * error falls within the maximum allowed error of {@code 2/sqrt(sampleSize)}. + */ + private void runApproximateUniquePipeline(int sampleSize) { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection collection = readPCollection(p); + + PCollection exact = collection.apply(RemoveDuplicates.create()) + .apply(Combine.globally(new CountElements())); + + PCollection approximate = + collection.apply(ApproximateUnique.globally(sampleSize)); + + EvaluationResults results = p.run(); + + verifyEstimate(results.getPCollection(exact).get(0).longValue(), sampleSize, + results.getPCollection(approximate).get(0).longValue()); + } + + /** + * Reads a large {@code PCollection}. + */ + private PCollection readPCollection(Pipeline p) { + // TODO: Read PCollection from a set of text files. + List page = TestUtils.LINES; + final int pages = 1000; + ArrayList file = new ArrayList<>(pages * page.size()); + for (int i = 0; i < pages; i++) { + file.addAll(page); + } + assert file.size() == pages * page.size(); + PCollection words = TestUtils.createStrings(p, file); + return words; + } + + /** + * Checks that the estimation error, i.e., the difference between + * {@code uniqueCount} and {@code estimate} is less than + * {@code 2 / sqrt(sampleSize}). + */ + private static void verifyEstimate(long uniqueCount, int sampleSize, + long estimate) { + if (uniqueCount < sampleSize) { + assertEquals("Number of hashes is less than the sample size. " + + "Estimate should be exact", uniqueCount, estimate); + } + + double error = 100.0 * Math.abs(estimate - uniqueCount) / uniqueCount; + double maxError = 100.0 * 2 / Math.sqrt(sampleSize); + + assertTrue("Estimate= " + estimate + " Actual=" + uniqueCount + " Error=" + + error + "%, MaxError=" + maxError + "%.", error < maxError); + } + + /** + * Combiner function counting the number of elements in an input PCollection. + * + * @param the type of elements in the input PCollection. + */ + private static class CountElements extends CombineFn { + + @Override + public Long[] createAccumulator() { + Long[] accumulator = new Long[1]; + accumulator[0] = 0L; + return accumulator; + } + + @Override + public void addInput(Long[] accumulator, E input) { + accumulator[0]++; + } + + @Override + public Long[] mergeAccumulators(Iterable accumulators) { + Long[] sum = new Long[1]; + sum[0] = 0L; + for (Long[] accumulator : accumulators) { + sum[0] += accumulator[0]; + } + return sum; + } + + @Override + public Long extractOutput(Long[] accumulator) { + return accumulator[0]; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, + Coder inputCoder) { + return SerializableCoder.of(Long[].class); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineTest.java new file mode 100644 index 000000000000..52b0b230a19d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineTest.java @@ -0,0 +1,527 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertThat; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.RecordingPipelineVisitor; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; + +/** + * Tests for Combine transforms. + */ +@RunWith(JUnit4.class) +public class CombineTest { + + @SuppressWarnings("unchecked") + static final KV[] TABLE = new KV[] { + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13), + }; + + @SuppressWarnings("unchecked") + static final KV[] EMPTY_TABLE = new KV[] { + }; + + static final Integer[] NUMBERS = new Integer[] { + 1, 1, 2, 3, 5, 8, 13, 21, 34, 55 + }; + + PCollection> createInput(Pipeline p, + KV[] table) { + return p.apply(Create.of(Arrays.asList(table))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + } + + private void runTestSimpleCombine(KV[] table, + int globalSum, + KV[] perKeySums) { + Pipeline p = TestPipeline.create(); + PCollection> input = createInput(p, table); + + PCollection sum = input + .apply(Values.create()) + .apply(Combine.globally(new SumInts())); + + // Java 8 will infer. + PCollection> sumPerKey = input + .apply(Combine.perKey(new SumInts())); + + DataflowAssert.that(sum).containsInAnyOrder(globalSum); + DataflowAssert.that(sumPerKey).containsInAnyOrder(perKeySums); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSimpleCombine() { + runTestSimpleCombine(TABLE, 20, new KV[] { + KV.of("a", 6), KV.of("b", 14) }); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSimpleCombineEmpty() { + runTestSimpleCombine(EMPTY_TABLE, 0, new KV[] { }); + } + + private void runTestBasicCombine(KV[] table, + Set globalUnique, + KV>[] perKeyUnique) { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(Set.class, SetCoder.class); + PCollection> input = createInput(p, table); + + PCollection> unique = input + .apply(Values.create()) + .apply(Combine.globally(new UniqueInts())); + + // Java 8 will infer. + PCollection>> uniquePerKey = input + .apply(Combine.>perKey(new UniqueInts())); + + DataflowAssert.that(unique).containsInAnyOrder(globalUnique); + DataflowAssert.that(uniquePerKey).containsInAnyOrder(perKeyUnique); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testBasicCombine() { + runTestBasicCombine(TABLE, ImmutableSet.of(1, 13, 4), new KV[] { + KV.of("a", (Set) ImmutableSet.of(1, 4)), + KV.of("b", (Set) ImmutableSet.of(1, 13)) }); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testBasicCombineEmpty() { + runTestBasicCombine(EMPTY_TABLE, ImmutableSet.of(), new KV[] { }); + } + + private void runTestAccumulatingCombine(KV[] table, + Double globalMean, + KV[] perKeyMeans) { + Pipeline p = TestPipeline.create(); + PCollection> input = createInput(p, table); + + PCollection mean = input + .apply(Values.create()) + .apply(Combine.globally(new MeanInts())); + + // Java 8 will infer. + PCollection> meanPerKey = input.apply( + Combine.perKey(new MeanInts())); + + DataflowAssert.that(mean).containsInAnyOrder(globalMean); + DataflowAssert.that(meanPerKey).containsInAnyOrder(perKeyMeans); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testWindowedCombineEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection mean = p + .apply(Create.of()).setCoder(BigEndianIntegerCoder.of()) + .apply(Window.into(FixedWindows.of(Duration.millis(1)))) + .apply(Combine.globally(new MeanInts())); + + DataflowAssert.that(mean).containsInAnyOrder(); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testAccumulatingCombine() { + runTestAccumulatingCombine(TABLE, 4.0, new KV[] { + KV.of("a", 2.0), KV.of("b", 7.0) }); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testAccumulatingCombineEmpty() { + runTestAccumulatingCombine(EMPTY_TABLE, 0.0, new KV[] { }); + } + + // Checks that Min, Max, Mean, Sum (operations which pass-through to Combine), + // provide their own top-level name. + @Test + public void testCombinerNames() { + Pipeline p = TestPipeline.create(); + PCollection> input = createInput(p, TABLE); + + Combine.PerKey min = Min.integersPerKey(); + Combine.PerKey max = Max.integersPerKey(); + Combine.PerKey mean = Mean.perKey(); + Combine.PerKey sum = Sum.integersPerKey(); + + input.apply(min); + input.apply(max); + input.apply(mean); + input.apply(sum); + + p.traverseTopologically(new RecordingPipelineVisitor()); + + assertThat(p.getFullName(min), Matchers.startsWith("Min")); + assertThat(p.getFullName(max), Matchers.startsWith("Max")); + assertThat(p.getFullName(mean), Matchers.startsWith("Mean")); + assertThat(p.getFullName(sum), Matchers.startsWith("Sum")); + } + + @Test + public void testAddInputsRandomly() { + TestCounter counter = new TestCounter(); + Combine.KeyedCombineFn< + String, Integer, TestCounter.Counter, Iterable> fn = + counter.asKeyedFn(); + + List accums = DirectPipelineRunner.TestCombineDoFn.addInputsRandomly( + fn, "bob", Arrays.asList(NUMBERS), new Random(42)); + + assertThat(accums, Matchers.contains( + counter.new Counter(3, 2, 0, 0), + counter.new Counter(131, 5, 0, 0), + counter.new Counter(8, 2, 0, 0), + counter.new Counter(1, 1, 0, 0))); + } + + //////////////////////////////////////////////////////////////////////////// + // Test classes, for different kinds of combining fns. + + /** Example SerializableFunction combiner. */ + public static class SumInts + implements SerializableFunction, Integer> { + @Override + public Integer apply(Iterable input) { + int sum = 0; + for (int item : input) { + sum += item; + } + return sum; + } + } + + /** Example CombineFn. */ + public static class UniqueInts extends + Combine.CombineFn, Set> { + + @Override + public Set createAccumulator() { + return new HashSet<>(); + } + + @Override + public void addInput(Set accumulator, Integer input) { + accumulator.add(input); + } + + @Override + public Set mergeAccumulators(Iterable> accumulators) { + Set all = new HashSet<>(); + for (Set part : accumulators) { + all.addAll(part); + } + return all; + } + + @Override + public Set extractOutput(Set accumulator) { + return accumulator; + } + } + + // Note: not a deterministic encoding + private static class SetCoder extends StandardCoder> { + + public static SetCoder of(Coder elementCoder) { + return new SetCoder<>(elementCoder); + } + + @JsonCreator + public static SetCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + public static List getInstanceComponents(Set exampleValue) { + return IterableCoder.getInstanceComponents(exampleValue); + } + + private final Coder> iterableCoder; + + private SetCoder(Coder elementCoder) { + iterableCoder = IterableCoder.of(elementCoder); + } + + @Override + public void encode(Set value, OutputStream outStream, Context context) + throws CoderException, IOException { + iterableCoder.encode(value, outStream, context); + } + + @Override + public Set decode(InputStream inStream, Context context) + throws CoderException, IOException { + // TODO: Eliminate extra copy if used in production. + return Sets.newHashSet(iterableCoder.decode(inStream, context)); + } + + @Override + public List> getCoderArguments() { + return iterableCoder.getCoderArguments(); + } + + @Override + public boolean isDeterministic() { + return false; + } + + @Override + public boolean isRegisterByteSizeObserverCheap(Set value, Context context) { + return iterableCoder.isRegisterByteSizeObserverCheap(value, context); + } + + @Override + public void registerByteSizeObserver( + Set value, ElementByteSizeObserver observer, Context context) + throws Exception { + iterableCoder.registerByteSizeObserver(value, observer, context); + } + } + + /** Example AccumulatingCombineFn. */ + public static class MeanInts extends + Combine.AccumulatingCombineFn { + private static final Coder LONG_CODER = BigEndianLongCoder.of(); + private static final Coder DOUBLE_CODER = DoubleCoder.of(); + + class CountSum extends + Combine.AccumulatingCombineFn.Accumulator { + long count = 0; + double sum = 0.0; + + CountSum(long count, double sum) { + this.count = count; + this.sum = sum; + } + + @Override + public void addInput(Integer element) { + count++; + sum += element.doubleValue(); + } + + @Override + public void mergeAccumulator(CountSum accumulator) { + count += accumulator.count; + sum += accumulator.sum; + } + + @Override + public Double extractOutput() { + return count == 0 ? 0.0 : sum / count; + } + } + + @Override + public CountSum createAccumulator() { + return new CountSum(0, 0.0); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return new CountSumCoder(); + } + + /** + * A Coder for CountSum + */ + public class CountSumCoder extends CustomCoder { + @Override + public void encode(CountSum value, OutputStream outStream, + Context context) throws CoderException, IOException { + LONG_CODER.encode(value.count, outStream, context); + DOUBLE_CODER.encode(value.sum, outStream, context); + } + + @Override + public CountSum decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + long count = LONG_CODER.decode(inStream, context); + double sum = DOUBLE_CODER.decode(inStream, context); + return new CountSum(count, sum); + } + + @Override + public boolean isDeterministic() { + return true; + } + + @Override + public boolean isRegisterByteSizeObserverCheap( + CountSum value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + CountSum value, ElementByteSizeObserver observer, Context context) + throws Exception { + LONG_CODER.registerByteSizeObserver(value.count, observer, context); + DOUBLE_CODER.registerByteSizeObserver(value.sum, observer, context); + } + } + } + + /** Another example AccumulatingCombineFn. */ + public static class TestCounter extends + Combine.AccumulatingCombineFn< + Integer, TestCounter.Counter, Iterable> { + + /** An accumulator that observes its merges and outputs */ + public class Counter extends + Combine.AccumulatingCombineFn< + Integer, Counter, Iterable>.Accumulator { + + public long sum = 0; + public long inputs = 0; + public long merges = 0; + public long outputs = 0; + + public Counter(long sum, long inputs, long merges, long outputs) { + this.sum = sum; + this.inputs = inputs; + this.merges = merges; + this.outputs = outputs; + } + + @Override + public void addInput(Integer element) { + Preconditions.checkState(merges == 0); + Preconditions.checkState(outputs == 0); + + inputs++; + sum += element; + } + + @Override + public void mergeAccumulator(Counter accumulator) { + Preconditions.checkState(outputs == 0); + Preconditions.checkArgument(accumulator.outputs == 0); + + merges += accumulator.merges + 1; + inputs += accumulator.inputs; + sum += accumulator.sum; + } + + @Override + public Iterable extractOutput() { + Preconditions.checkState(outputs == 0); + + return Arrays.asList(sum, inputs, merges, outputs); + } + + @Override + public int hashCode() { + return (int) (sum * 17 + inputs * 31 + merges * 43 + outputs * 181); + } + + @Override + public boolean equals(Object otherObj) { + if (otherObj instanceof Counter) { + Counter other = (Counter) otherObj; + return (sum == other.sum + && inputs == other.inputs + && merges == other.merges + && outputs == other.outputs); + } + return false; + } + + public String toString() { + return sum + ":" + inputs + ":" + merges + ":" + outputs; + } + } + + @Override + public Counter createAccumulator() { + return new Counter(0, 0, 0, 0); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return SerializableCoder.of(Counter.class); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CountTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CountTest.java new file mode 100644 index 000000000000..05375bd7c536 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CountTest.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.createStrings; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for Count. + */ +@RunWith(JUnit4.class) +public class CountTest { + static final String[] WORDS_ARRAY = new String[] { + "hi", "there", "hi", "hi", "sue", "bob", + "hi", "sue", "", "", "ZOW", "bob", "" }; + + static final List WORDS = Arrays.asList(WORDS_ARRAY); + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCountPerElementBasic() { + Pipeline p = TestPipeline.create(); + + PCollection input = createStrings(p, WORDS); + + PCollection> output = + input.apply(Count.perElement()); + + DataflowAssert.that(output) + .containsInAnyOrder( + KV.of("hi", 4L), + KV.of("there", 1L), + KV.of("sue", 2L), + KV.of("bob", 2L), + KV.of("", 3L), + KV.of("ZOW", 1L)); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCountPerElementEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection input = createStrings(p, NO_LINES); + + PCollection> output = + input.apply(Count.perElement()); + + DataflowAssert.that(output) + .containsInAnyOrder(); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCountGloballyBasic() { + Pipeline p = TestPipeline.create(); + + PCollection input = createStrings(p, WORDS); + + PCollection output = + input.apply(Count.globally()); + + DataflowAssert.that(output) + .containsInAnyOrder(13L); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCountGloballyEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection input = createStrings(p, NO_LINES); + + PCollection output = + input.apply(Count.globally()); + + DataflowAssert.that(output) + .containsInAnyOrder(0L); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CreateTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CreateTest.java new file mode 100644 index 000000000000..8202da086240 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CreateTest.java @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES_ARRAY; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for Create. + */ +@RunWith(JUnit4.class) +public class CreateTest { + @Rule public final ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCreate() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(LINES)); + + DataflowAssert.that(output) + .containsInAnyOrder(LINES_ARRAY); + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testCreateOrdered() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(LINES)) + .setOrdered(true); + + DataflowAssert.that(output) + .containsInOrder(LINES_ARRAY); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCreateEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(NO_LINES)) + .setCoder(StringUtf8Coder.of()); + + DataflowAssert.that(output) + .containsInAnyOrder(NO_LINES_ARRAY); + p.run(); + } + + static class Record implements Serializable { + } + + static class Record2 extends Record { + } + + @Test + public void testPolymorphicType() throws Exception { + thrown.expect(RuntimeException.class); + thrown.expectMessage( + Matchers.containsString("unable to infer a default Coder")); + + Pipeline p = TestPipeline.create(); + + // Create won't infer a default coder in this case. + p.apply(Create.of(new Record(), new Record2())); + + p.run(); + } + + @Test + public void testCreateParameterizedType() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection> output = + p.apply(Create.of( + TimestampedValue.of("a", new Instant(0)), + TimestampedValue.of("b", new Instant(0)))); + + DataflowAssert.that(output) + .containsInAnyOrder( + TimestampedValue.of("a", new Instant(0)), + TimestampedValue.of("b", new Instant(0))); + } + + private static class PrintTimestamps extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + ":" + c.timestamp().getMillis()); + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCreateTimestamped() { + Pipeline p = TestPipeline.create(); + + List> data = Arrays.asList( + TimestampedValue.of("a", new Instant(1L)), + TimestampedValue.of("b", new Instant(2L)), + TimestampedValue.of("c", new Instant(3L))); + + PCollection output = + p.apply(Create.timestamped(data)) + .apply(ParDo.of(new PrintTimestamps())); + + DataflowAssert.that(output) + .containsInAnyOrder("a:1", "b:2", "c:3"); + p.run(); + } + + @Test + // This test fails when run on the service! + // TODO: @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testCreateTimestampedEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.timestamped(new ArrayList>())) + .setCoder(StringUtf8Coder.of()); + + DataflowAssert.that(output) + .containsInAnyOrder(); + p.run(); + } + + @Test + public void testCreateTimestampedPolymorphicType() throws Exception { + thrown.expect(RuntimeException.class); + thrown.expectMessage( + Matchers.containsString("unable to infer a default Coder")); + + Pipeline p = TestPipeline.create(); + + // Create won't infer a default coder in this case. + PCollection c = p.apply(Create.timestamped( + TimestampedValue.of(new Record(), new Instant(0)), + TimestampedValue.of(new Record2(), new Instant(0)))); + + p.run(); + + + throw new RuntimeException("Coder: " + c.getCoder()); + + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FirstTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FirstTest.java new file mode 100644 index 000000000000..bcd14d0c6e89 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FirstTest.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +/** + * Tests for First. + */ +@RunWith(JUnit4.class) +public class FirstTest + implements Serializable /* to allow anon inner classes */ { + // PRE: lines contains no duplicates. + void runTestFirst(final List lines, int limit, boolean ordered) { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(lines)) + .setCoder(StringUtf8Coder.of()); + + if (ordered) { + input.setOrdered(true); + } + + PCollection output = + input.apply(First.of(limit)); + + if (ordered) { + output.setOrdered(true); + } + + final int expectedSize = Math.min(limit, lines.size()); + if (ordered) { + List expected = lines.subList(0, expectedSize); + if (expected.isEmpty()) { + DataflowAssert.that(output) + .containsInAnyOrder(expected); + } else { + DataflowAssert.that(output) + .containsInOrder(expected); + } + } else { + DataflowAssert.that(output) + .satisfies(new SerializableFunction, Void>() { + @Override + public Void apply(Iterable actualIter) { + // Make sure actual is the right length, and is a + // subset of expected. + List actual = new ArrayList<>(); + for (String s : actualIter) { + actual.add(s); + } + assertEquals(expectedSize, actual.size()); + Set actualAsSet = new TreeSet<>(actual); + Set linesAsSet = new TreeSet<>(lines); + assertEquals(actual.size(), actualAsSet.size()); + assertEquals(lines.size(), linesAsSet.size()); + assertTrue(linesAsSet.containsAll(actualAsSet)); + return null; + } + }); + } + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testFirst() { + runTestFirst(LINES, 0, false); + runTestFirst(LINES, LINES.size() / 2, false); + runTestFirst(LINES, LINES.size() * 2, false); + } + + @Test + // Extra tests, not worth the time to run on the real service. + public void testFirstMore() { + runTestFirst(LINES, LINES.size() - 1, false); + runTestFirst(LINES, LINES.size(), false); + runTestFirst(LINES, LINES.size() + 1, false); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testFirstOrdered() { + runTestFirst(LINES, 0, true); + runTestFirst(LINES, LINES.size() / 2, true); + runTestFirst(LINES, LINES.size() - 1, true); + runTestFirst(LINES, LINES.size(), true); + runTestFirst(LINES, LINES.size() + 1, true); + runTestFirst(LINES, LINES.size() * 2, true); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testFirstEmpty() { + runTestFirst(NO_LINES, 0, false); + runTestFirst(NO_LINES, 1, false); + } + + @Test + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + public void testFirstEmptyOrdered() { + runTestFirst(NO_LINES, 0, true); + runTestFirst(NO_LINES, 1, true); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlattenTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlattenTest.java new file mode 100644 index 000000000000..70cc4f1eaf88 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlattenTest.java @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES2; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES_ARRAY; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.joda.time.Duration; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for Flatten. + */ +@RunWith(JUnit4.class) +public class FlattenTest { + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testFlattenPCollectionListUnordered() { + Pipeline p = TestPipeline.create(); + + List[] inputs = new List[] { + LINES, NO_LINES, LINES2, NO_LINES, LINES, NO_LINES }; + + PCollection output = + makePCollectionListOfStrings(false /* not ordered */, p, inputs) + .apply(Flatten.pCollections()); + + DataflowAssert.that(output).containsInAnyOrder(flatten(inputs)); + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testFlattenPCollectionListOrdered() { + Pipeline p = TestPipeline.create(); + + List[] inputs = new List[] { + LINES, NO_LINES, LINES2, NO_LINES, LINES, NO_LINES }; + + PCollection output = + makePCollectionListOfStrings(true /* ordered */, p, inputs) + .apply(Flatten.pCollections()).setOrdered(true); + + DataflowAssert.that(output).containsInOrder(flatten(inputs)); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testFlattenPCollectionListEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection output = + PCollectionList.empty(p) + .apply(Flatten.pCollections()).setCoder(StringUtf8Coder.of()); + + DataflowAssert.that(output).containsInAnyOrder(); + p.run(); + } + + @Test + public void testWindowingFnPropagationFailure() { + Pipeline p = TestPipeline.create(); + + PCollection input1 = + p.apply(Create.of("Input1")) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + PCollection input2 = + p.apply(Create.of("Input2")) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(2)))); + + try { + PCollection output = + PCollectionList.of(input1).and(input2) + .apply(Flatten.create()); + Assert.fail("Exception should have been thrown"); + } catch (IllegalStateException e) { + Assert.assertTrue(e.getMessage().startsWith( + "Inputs to Flatten had incompatible window windowingFns")); + } + } + + @Test + public void testWindowingFnPropagation() { + Pipeline p = TestPipeline.create(); + + PCollection input1 = + p.apply(Create.of("Input1")) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + PCollection input2 = + p.apply(Create.of("Input2")) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + + PCollection output = + PCollectionList.of(input1).and(input2) + .apply(Flatten.create()); + + p.run(); + + Assert.assertTrue(output.getWindowingFn().isCompatible( + FixedWindows.of(Duration.standardMinutes(1)))); + } + + @Test + public void testEqualWindowingFnPropagation() { + Pipeline p = TestPipeline.create(); + + PCollection input1 = + p.apply(Create.of("Input1")) + .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(1)))); + PCollection input2 = + p.apply(Create.of("Input2")) + .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(2)))); + + PCollection output = + PCollectionList.of(input1).and(input2) + .apply(Flatten.create()); + + p.run(); + + Assert.assertTrue(output.getWindowingFn().isCompatible( + Sessions.withGapDuration(Duration.standardMinutes(2)))); + } + + + PCollectionList makePCollectionListOfStrings(boolean ordered, + Pipeline p, + List... lists) { + return makePCollectionList(ordered, p, StringUtf8Coder.of(), lists); + } + + PCollectionList makePCollectionList(boolean ordered, + Pipeline p, + Coder coder, + List... lists) { + List> pcs = new ArrayList<>(); + for (List list : lists) { + PCollection pc = p.apply(Create.of(list)).setCoder(coder); + if (ordered) { + pc.setOrdered(true); + } + pcs.add(pc); + } + return PCollectionList.of(pcs); + } + + T[] flatten(List... lists) { + List flattened = new ArrayList<>(); + for (List list : lists) { + flattened.addAll(list); + } + return flattened.toArray((T[]) new Object[flattened.size()]); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testFlattenIterables() { + Pipeline p = TestPipeline.create(); + + PCollection> input = p + .apply(Create.>of(LINES)) + .setCoder(IterableCoder.of(StringUtf8Coder.of())); + + PCollection output = + input.apply(Flatten.iterables()); + + DataflowAssert.that(output) + .containsInAnyOrder(LINES_ARRAY); + + p.run(); + } + + @Test + public void testFlattenIterablesOrdered() { + Pipeline p = TestPipeline.create(); + + PCollection> input = p + .apply(Create.>of(LINES)) + .setCoder(IterableCoder.of(StringUtf8Coder.of())); + + PCollection output = + input.apply(Flatten.iterables()).setOrdered(true); + + DataflowAssert.that(output) + .containsInOrder(LINES_ARRAY); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testFlattenIterablesEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = p + .apply(Create.>of(NO_LINES)) + .setCoder(IterableCoder.of(StringUtf8Coder.of())); + + PCollection output = + input.apply(Flatten.iterables()); + + DataflowAssert.that(output) + .containsInAnyOrder(NO_LINES_ARRAY); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/GroupByKeyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/GroupByKeyTest.java new file mode 100644 index 000000000000..ebb141f38b7b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/GroupByKeyTest.java @@ -0,0 +1,280 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.KvMatcher.isKv; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.InvalidWindowingFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * Tests for GroupByKey. + */ +@RunWith(JUnit4.class) +public class GroupByKeyTest { + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testGroupByKey() { + List> ungroupedPairs = Arrays.asList( + KV.of("k1", 3), + KV.of("k5", Integer.MAX_VALUE), + KV.of("k5", Integer.MIN_VALUE), + KV.of("k2", 66), + KV.of("k1", 4), + KV.of("k2", -33), + KV.of("k3", 0)); + + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection>> output = + input.apply(GroupByKey.create()); + + DataflowAssert.that(output) + .satisfies(new AssertThatHasExpectedContentsForTestGroupByKey()); + + p.run(); + } + + static class AssertThatHasExpectedContentsForTestGroupByKey + implements SerializableFunction>>, + Void> { + @Override + public Void apply(Iterable>> actual) { + assertThat(actual, containsInAnyOrder( + isKv(is("k1"), containsInAnyOrder(3, 4)), + isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, + Integer.MIN_VALUE)), + isKv(is("k2"), containsInAnyOrder(66, -33)), + isKv(is("k3"), containsInAnyOrder(0)))); + return null; + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testGroupByKeyAndWindows() { + List> ungroupedPairs = Arrays.asList( + KV.of("k1", 3), // window [0, 5) + KV.of("k5", Integer.MAX_VALUE), // window [0, 5) + KV.of("k5", Integer.MIN_VALUE), // window [0, 5) + KV.of("k2", 66), // window [0, 5) + KV.of("k1", 4), // window [5, 10) + KV.of("k2", -33), // window [5, 10) + KV.of("k3", 0)); // window [5, 10) + + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.timestamped(ungroupedPairs, Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L))) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + PCollection>> output = + input.apply(Window.>into(FixedWindows.of(new Duration(5)))) + .apply(GroupByKey.create()); + + DataflowAssert.that(output) + .satisfies(new AssertThatHasExpectedContentsForTestGroupByKeyAndWindows()); + + p.run(); + } + + static class AssertThatHasExpectedContentsForTestGroupByKeyAndWindows + implements SerializableFunction>>, + Void> { + @Override + public Void apply(Iterable>> actual) { + assertThat(actual, containsInAnyOrder( + isKv(is("k1"), containsInAnyOrder(3)), + isKv(is("k1"), containsInAnyOrder(4)), + isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, + Integer.MIN_VALUE)), + isKv(is("k2"), containsInAnyOrder(66)), + isKv(is("k2"), containsInAnyOrder(-33)), + isKv(is("k3"), containsInAnyOrder(0)))); + return null; + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testGroupByKeyEmpty() { + List> ungroupedPairs = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection>> output = + input.apply(GroupByKey.create()); + + DataflowAssert.that(output) + .containsInAnyOrder(); + + p.run(); + } + + @Test + public void testGroupByKeyNonDeterministic() throws Exception { + expectedEx.expect(IllegalStateException.class); + expectedEx.expectMessage(Matchers.containsString("must be deterministic")); + + List, Integer>> ungroupedPairs = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection, Integer>> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder( + KvCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), + BigEndianIntegerCoder.of())); + + input.apply(GroupByKey., Integer>create()); + + p.run(); + } + + @Test + public void testIdentityWindowingFnPropagation() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())) + .apply(Window.>into(FixedWindows.of(Duration.standardMinutes(1)))); + + PCollection>> output = + input.apply(GroupByKey.create()); + + p.run(); + + Assert.assertTrue(output.getWindowingFn().isCompatible( + FixedWindows.>of(Duration.standardMinutes(1)))); + + } + + @Test + public void testWindowingFnInvalidation() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + + PCollection>> output = + input.apply(GroupByKey.create()); + + p.run(); + + Assert.assertTrue( + output.getWindowingFn().isCompatible( + new InvalidWindowingFn( + "Invalid", + Sessions.>withGapDuration( + Duration.standardMinutes(1))))); + } + + @Test + public void testInvalidWindowingFn() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + + try { + PCollection>>> output = input + .apply(GroupByKey.create()) + .apply(GroupByKey.>create()); + Assert.fail("Exception should have been thrown"); + } catch (IllegalStateException e) { + Assert.assertTrue(e.getMessage().startsWith( + "GroupByKey must have a valid Window merge function.")); + } + } + + @Test + public void testRemerge() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + + PCollection>>> middle = input + .apply(GroupByKey.create()) + .apply(Window.>>remerge()) + .apply(GroupByKey.>create()) + .apply(Window.>>>remerge()); + + p.run(); + + Assert.assertTrue( + middle.getWindowingFn().isCompatible( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KeysTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KeysTest.java new file mode 100644 index 000000000000..1d6e233adef8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KeysTest.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for Keys transform. + */ +@RunWith(JUnit4.class) +public class KeysTest { + static final KV[] TABLE = new KV[] { + KV.of("one", 1), + KV.of("two", 2), + KV.of("three", 3), + KV.of("dup", 4), + KV.of("dup", 5) + }; + + static final KV[] EMPTY_TABLE = new KV[] { + }; + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testKeys() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection output = input.apply(Keys.create()); + DataflowAssert.that(output) + .containsInAnyOrder("one", "two", "three", "dup", "dup"); + + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testKeysOrdered() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + input.setOrdered(true); + PCollection output = + input.apply(Keys.create()).setOrdered(true); + DataflowAssert.that(output) + .containsInOrder("one", "two", "three", "dup", "dup"); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testKeysEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(EMPTY_TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection output = input.apply(Keys.create()); + DataflowAssert.that(output) + .containsInAnyOrder(); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KvSwapTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KvSwapTest.java new file mode 100644 index 000000000000..15c2ff2ff736 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KvSwapTest.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for KvSwap transform. + */ +@RunWith(JUnit4.class) +public class KvSwapTest { + static final KV[] TABLE = new KV[] { + KV.of("one", 1), + KV.of("two", 2), + KV.of("three", 3), + KV.of("four", 4), + KV.of("dup", 4), + KV.of("dup", 5) + }; + + static final KV[] EMPTY_TABLE = new KV[] { + }; + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testKvSwap() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection> output = input.apply( + KvSwap.create()); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of(1, "one"), + KV.of(2, "two"), + KV.of(3, "three"), + KV.of(4, "four"), + KV.of(4, "dup"), + KV.of(5, "dup")); + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testKvSwapOrdered() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + input.setOrdered(true); + PCollection> output = input.apply( + KvSwap.create()).setOrdered(true); + + DataflowAssert.that(output).containsInOrder( + KV.of(1, "one"), + KV.of(2, "two"), + KV.of(3, "three"), + KV.of(4, "four"), + KV.of(4, "dup"), + KV.of(5, "dup")); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testKvSwapEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(EMPTY_TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection> output = input.apply( + KvSwap.create()); + + DataflowAssert.that(output).containsInAnyOrder(); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java new file mode 100644 index 000000000000..7e46bb31a785 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java @@ -0,0 +1,986 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.createInts; +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; +import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.hamcrest.core.AnyOf.anyOf; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.api.client.util.Preconditions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for ParDo. + */ +@RunWith(JUnit4.class) +public class ParDoTest implements Serializable { + // This test is Serializable, just so that it's easy to have + // anonymous inner classes inside the non-static test methods. + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + + static class TestDoFn extends DoFn { + enum State { UNSTARTED, STARTED, PROCESSING, FINISHED } + State state = State.UNSTARTED; + + final List> sideInputViews = new ArrayList<>(); + final List> sideOutputTupleTags = new ArrayList<>(); + + public TestDoFn() { + } + + public TestDoFn(List> sideInputViews, + List> sideOutputTupleTags) { + this.sideInputViews.addAll(sideInputViews); + this.sideOutputTupleTags.addAll(sideOutputTupleTags); + } + + @Override + public void startBundle(Context c) { + assertEquals(State.UNSTARTED, state); + state = State.STARTED; + outputToAll(c, "started"); + } + + @Override + public void processElement(ProcessContext c) { + assertThat(state, + anyOf(equalTo(State.STARTED), equalTo(State.PROCESSING))); + state = State.PROCESSING; + outputToAll(c, "processing: " + c.element()); + } + + @Override + public void finishBundle(Context c) { + assertThat(state, + anyOf(equalTo(State.STARTED), equalTo(State.PROCESSING))); + state = State.FINISHED; + outputToAll(c, "finished"); + } + + private void outputToAll(Context c, String value) { + if (!sideInputViews.isEmpty()) { + List sideInputValues = new ArrayList<>(); + for (PCollectionView sideInputView : sideInputViews) { + sideInputValues.add(c.sideInput(sideInputView)); + } + value += ": " + sideInputValues; + } + c.output(value); + for (TupleTag sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); + } + } + + /** DataflowAssert "matcher" for expected output. */ + static class HasExpectedOutput + implements SerializableFunction, Void>, Serializable { + private final List inputs; + private final List sideInputs; + private final String sideOutput; + private final boolean ordered; + + public static HasExpectedOutput forInput(List inputs) { + return new HasExpectedOutput( + new ArrayList(inputs), + new ArrayList(), + "", + false); + } + + private HasExpectedOutput(List inputs, + List sideInputs, + String sideOutput, + boolean ordered) { + this.inputs = inputs; + this.sideInputs = sideInputs; + this.sideOutput = sideOutput; + this.ordered = ordered; + } + + public HasExpectedOutput andSideInputs(Integer... sideInputValues) { + List sideInputs = new ArrayList<>(); + for (Integer sideInputValue : sideInputValues) { + sideInputs.add(sideInputValue); + } + return new HasExpectedOutput(inputs, sideInputs, sideOutput, ordered); + } + + public HasExpectedOutput fromSideOutput(TupleTag sideOutputTag) { + return fromSideOutput(sideOutputTag.getId()); + } + public HasExpectedOutput fromSideOutput(String sideOutput) { + return new HasExpectedOutput(inputs, sideInputs, sideOutput, ordered); + } + + public HasExpectedOutput inOrder() { + return new HasExpectedOutput(inputs, sideInputs, sideOutput, true); + } + + @Override + public Void apply(Iterable outputs) { + List starteds = new ArrayList<>(); + List processeds = new ArrayList<>(); + List finisheds = new ArrayList<>(); + for (String output : outputs) { + if (output.contains("started")) { + starteds.add(output); + } else if (output.contains("finished")) { + finisheds.add(output); + } else { + processeds.add(output); + } + } + + String sideInputsSuffix; + if (sideInputs.isEmpty()) { + sideInputsSuffix = ""; + } else { + sideInputsSuffix = ": " + sideInputs; + } + + String sideOutputPrefix; + if (sideOutput.isEmpty()) { + sideOutputPrefix = ""; + } else { + sideOutputPrefix = sideOutput + ": "; + } + + List expectedProcesseds = new ArrayList<>(); + for (Integer input : inputs) { + expectedProcesseds.add( + sideOutputPrefix + "processing: " + input + sideInputsSuffix); + } + String[] expectedProcessedsArray = + expectedProcesseds.toArray(new String[expectedProcesseds.size()]); + if (!ordered || expectedProcesseds.isEmpty()) { + assertThat(processeds, containsInAnyOrder(expectedProcessedsArray)); + } else { + assertThat(processeds, contains(expectedProcessedsArray)); + } + + assertEquals(starteds.size(), finisheds.size()); + assertTrue(starteds.size() > 0); + for (String started : starteds) { + assertEquals(sideOutputPrefix + "started" + sideInputsSuffix, + started); + } + for (String finished : finisheds) { + assertEquals(sideOutputPrefix + "finished" + sideInputsSuffix, + finished); + } + + return null; + } + } + } + + static class TestStartBatchErrorDoFn extends DoFn { + @Override + public void startBundle(Context c) { + throw new RuntimeException("test error in initialize"); + } + + @Override + public void processElement(ProcessContext c) { + // This has to be here. + } + } + + static class TestProcessElementErrorDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + throw new RuntimeException("test error in process"); + } + } + + static class TestFinishBatchErrorDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + // This has to be here. + } + + @Override + public void finishBundle(Context c) { + throw new RuntimeException("test error in finalize"); + } + } + + static class TestUnexpectedKeyedStateDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + // Will fail since this DoFn doesn't implement RequiresKeyedState. + c.keyedState(); + } + } + + private static class StrangelyNamedDoer extends DoFn { + @Override + public void processElement(ProcessContext c) { + } + } + + static class TestOutputTimestampDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Integer value = c.element(); + c.outputWithTimestamp(value, new Instant(value.longValue())); + } + } + + static class TestShiftTimestampDoFn extends DoFn { + private Duration allowedTimestampSkew; + private Duration durationToShift; + + public TestShiftTimestampDoFn(Duration allowedTimestampSkew, + Duration durationToShift) { + this.allowedTimestampSkew = allowedTimestampSkew; + this.durationToShift = durationToShift; + } + + @Override + public Duration getAllowedTimestampSkew() { + return allowedTimestampSkew; + } + @Override + public void processElement(ProcessContext c) { + Instant timestamp = c.timestamp(); + Preconditions.checkNotNull(timestamp); + Integer value = c.element(); + c.outputWithTimestamp(value, timestamp.plus(durationToShift)); + } + } + + static class TestFormatTimestampDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Preconditions.checkNotNull(c.timestamp()); + c.output("processing: " + c.element() + ", timestamp: " + c.timestamp().getMillis()); + } + } + + static class MultiFilter + extends PTransform, PCollectionTuple> { + + private static final TupleTag BY2 = new TupleTag("by2"){}; + private static final TupleTag BY3 = new TupleTag("by3"){}; + + @Override + public PCollectionTuple apply(PCollection input) { + PCollection by2 = input.apply(ParDo.of(new FilterFn(2))); + PCollection by3 = input.apply(ParDo.of(new FilterFn(3))); + return PCollectionTuple.of(BY2, by2).and(BY3, by3); + } + + static class FilterFn extends DoFn { + private final int divisor; + + FilterFn(int divisor) { + this.divisor = divisor; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + if (c.element() % divisor == 0) { + c.output(c.element()); + } + } + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testParDo() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + PCollection output = + input + .apply(ParDo.of(new TestDoFn())); + + DataflowAssert.that(output) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs)); + + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testParDoOrdered() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs).setOrdered(true); + + PCollection output = + input + .apply(ParDo.of(new TestDoFn())).setOrdered(true); + + DataflowAssert.that(output) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs).inOrder()); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testParDoEmpty() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(); + + PCollection input = createInts(p, inputs); + + PCollection output = + input + .apply(ParDo.of(new TestDoFn())); + + DataflowAssert.that(output) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs)); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testParDoWithSideOutputs() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + TupleTag mainTag = new TupleTag("main"){}; + TupleTag sideTag1 = new TupleTag("side1"){}; + TupleTag sideTag2 = new TupleTag("side2"){}; + TupleTag sideTag3 = new TupleTag("side3"){}; + TupleTag sideTagUnwritten = new TupleTag("sideUnwritten"){}; + + PCollectionTuple outputs = + input + .apply(ParDo + .of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideTag1, sideTag2, sideTag3))) + .withOutputTags( + mainTag, + TupleTagList.of(sideTag3).and(sideTag1) + .and(sideTagUnwritten).and(sideTag2))); + + DataflowAssert.that(outputs.get(mainTag)) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs)); + + DataflowAssert.that(outputs.get(sideTag1)) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideTag1)); + DataflowAssert.that(outputs.get(sideTag2)) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideTag2)); + DataflowAssert.that(outputs.get(sideTag3)) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideTag3)); + DataflowAssert.that(outputs.get(sideTagUnwritten)).containsInAnyOrder(); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testParDoWithOnlySideOutputs() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + final TupleTag mainTag = new TupleTag("main"){}; + final TupleTag sideTag = new TupleTag("side"){}; + + PCollectionTuple outputs = input.apply( + ParDo + .withOutputTags(mainTag, TupleTagList.of(sideTag)) + .of(new DoFn(){ + @Override + public void processElement(ProcessContext c) { + c.sideOutput(sideTag, c.element()); + }})); + + DataflowAssert.that(outputs.get(mainTag)).containsInAnyOrder(); + DataflowAssert.that(outputs.get(sideTag)).containsInAnyOrder(inputs); + + p.run(); + } + + @Test + public void testParDoWritingToUndeclaredSideOutput() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + TupleTag sideTag = new TupleTag("side"){}; + + PCollection output = + input + .apply(ParDo.of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideTag)))); + + DataflowAssert.that(output) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs)); + + p.run(); + } + + @Test + public void testParDoUndeclaredSideOutputLimit() { + Pipeline p = TestPipeline.create(); + PCollection input = createInts(p, Arrays.asList(3)); + + // Success for a total of 1000 outputs. + input + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + TupleTag specialSideTag = new TupleTag(){}; + c.sideOutput(specialSideTag, "side"); + c.sideOutput(specialSideTag, "side"); + c.sideOutput(specialSideTag, "side"); + + for (int i = 0; i < 998; i++) { + c.sideOutput(new TupleTag(){}, "side"); + } + }})); + p.run(); + + // Failure for a total of 1001 outputs. + input + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (int i = 0; i < 1000; i++) { + c.sideOutput(new TupleTag(){}, "side"); + } + }})); + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + assertThat(exn.toString(), + containsString("the number of side outputs has exceeded a limit")); + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testParDoWithSideInputs() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + PCollectionView sideInput1 = TestUtils.createSingletonInt(p, 11); + PCollectionView sideInputUnread = TestUtils.createSingletonInt(p, -3333); + PCollectionView sideInput2 = TestUtils.createSingletonInt(p, 222); + + PCollection output = + input + .apply(ParDo + .withSideInputs(sideInput1, sideInputUnread, sideInput2) + .of(new TestDoFn( + Arrays.asList(sideInput1, sideInput2), + Arrays.>asList()))); + + DataflowAssert.that(output) + .satisfies(TestDoFn.HasExpectedOutput + .forInput(inputs) + .andSideInputs(11, 222)); + + p.run(); + } + + @Test + public void testParDoReadingFromUnknownSideInput() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + PCollectionView sideView = TestUtils.createSingletonInt(p, 3); + + input + .apply(ParDo.of(new TestDoFn( + Arrays.>asList(sideView), + Arrays.>asList()))); + + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + assertThat(exn.toString(), + containsString("calling sideInput() with unknown view")); + } + } + + @Test + public void testParDoWithErrorInStartBatch() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + input + .apply(ParDo.of(new TestStartBatchErrorDoFn())); + + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + assertThat(exn.toString(), containsString("test error in initialize")); + } + } + + @Test + public void testParDoWithErrorInProcessElement() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + input + .apply(ParDo.of(new TestProcessElementErrorDoFn())); + + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + assertThat(exn.toString(), containsString("test error in process")); + } + } + + @Test + public void testParDoWithErrorInFinishBatch() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + input + .apply(ParDo.of(new TestFinishBatchErrorDoFn())); + + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + assertThat(exn.toString(), containsString("test error in finalize")); + } + } + + @Test + public void testParDoWithUnexpectedKeyedState() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + input + .apply(ParDo.of(new TestUnexpectedKeyedStateDoFn())); + + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + assertThat(exn.toString(), + containsString("Keyed state is only available")); + } + } + + @Test + public void testParDoName() { + Pipeline p = TestPipeline.create(); + + PCollection input = + createInts(p, Arrays.asList(3, -42, 666)) + .setName("MyInput"); + + { + PCollection output1 = + input + .apply(ParDo.of(new TestDoFn())); + assertEquals("Test.out", output1.getName()); + } + + { + PCollection output2 = + input + .apply(ParDo.named("MyParDo").of(new TestDoFn())); + assertEquals("MyParDo.out", output2.getName()); + } + + { + PCollection output3 = + input + .apply(ParDo.of(new TestDoFn()).named("HerParDo")); + assertEquals("HerParDo.out", output3.getName()); + } + + { + PCollection output4 = + input + .apply(ParDo.of(new TestDoFn()).named("TestDoFn")); + assertEquals("TestDoFn.out", output4.getName()); + } + + { + PCollection output5 = + input + .apply(ParDo.of(new StrangelyNamedDoer())); + assertEquals("StrangelyNamedDoer.out", + output5.getName()); + } + } + + @Test + public void testParDoWithSideOutputsName() { + Pipeline p = TestPipeline.create(); + + PCollection input = + createInts(p, Arrays.asList(3, -42, 666)) + .setName("MyInput"); + + TupleTag mainTag = new TupleTag("main"){}; + TupleTag sideTag1 = new TupleTag("side1"){}; + TupleTag sideTag2 = new TupleTag("side2"){}; + TupleTag sideTag3 = new TupleTag("side3"){}; + TupleTag sideTagUnwritten = new TupleTag("sideUnwritten"){}; + + PCollectionTuple outputs = + input + .apply(ParDo + .named("MyParDo") + .of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideTag1, sideTag2, sideTag3))) + .withOutputTags( + mainTag, + TupleTagList.of(sideTag3).and(sideTag1) + .and(sideTagUnwritten).and(sideTag2))); + + assertEquals("MyParDo.main", outputs.get(mainTag).getName()); + assertEquals("MyParDo.side1", outputs.get(sideTag1).getName()); + assertEquals("MyParDo.side2", outputs.get(sideTag2).getName()); + assertEquals("MyParDo.side3", outputs.get(sideTag3).getName()); + assertEquals("MyParDo.sideUnwritten", + outputs.get(sideTagUnwritten).getName()); + } + + @Test + public void testParDoInCustomTransform() { + Pipeline p = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection input = createInts(p, inputs); + + PCollection output = + input + .apply(new PTransform, PCollection>() { + @Override + public PCollection apply(PCollection input) { + return input.apply(ParDo.of(new TestDoFn())); + } + }); + + // Test that Coder inference of the result works through + // user-defined PTransforms. + DataflowAssert.that(output) + .satisfies(TestDoFn.HasExpectedOutput.forInput(inputs)); + + p.run(); + } + + @Test + public void testMultiOutputChaining() { + Pipeline p = TestPipeline.create(); + + PCollection input = createInts(p, Arrays.asList(3, 4, 5, 6)); + + PCollectionTuple filters = input.apply(new MultiFilter()); + PCollection by2 = filters.get(MultiFilter.BY2); + PCollection by3 = filters.get(MultiFilter.BY3); + + // Apply additional filters to each operation. + PCollection by2then3 = by2 + .apply(ParDo.of(new MultiFilter.FilterFn(3))); + PCollection by3then2 = by3 + .apply(ParDo.of(new MultiFilter.FilterFn(2))); + + DataflowAssert.that(by2then3).containsInAnyOrder(6); + DataflowAssert.that(by3then2).containsInAnyOrder(6); + p.run(); + } + + @Test + public void testJsonEscaping() { + // Declare an arbitrary function and make sure we can serialize it + DoFn doFn = new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + 1); + } + }; + + byte[] serializedBytes = serializeToByteArray(doFn); + String serializedJson = byteArrayToJsonString(serializedBytes); + assertArrayEquals( + serializedBytes, jsonStringToByteArray(serializedJson)); + } + + private static class TestDummy { } + + private static class TestDummyCoder extends AtomicCoder { + private TestDummyCoder() { } + private static final TestDummyCoder INSTANCE = new TestDummyCoder(); + + @JsonCreator + public static TestDummyCoder of() { + return INSTANCE; + } + + public static List getInstanceComponents(TestDummy exampleValue) { + return Collections.emptyList(); + } + + @Override + public void encode(TestDummy value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public TestDummy decode(InputStream inStream, Context context) + throws CoderException, IOException { + return new TestDummy(); + } + + @Override + public boolean isDeterministic() { return true; } + + @Override + public boolean isRegisterByteSizeObserverCheap(TestDummy value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + TestDummy value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(0L); + } + } + + private static class SideOutputDummyFn extends DoFn { + private TupleTag sideTag; + public SideOutputDummyFn(TupleTag sideTag) { + this.sideTag = sideTag; + } + @Override + public void processElement(ProcessContext c) { + c.output(1); + c.sideOutput(sideTag, new TestDummy()); + } + } + + private static class MainOutputDummyFn extends DoFn { + private TupleTag sideTag; + public MainOutputDummyFn(TupleTag sideTag) { + this.sideTag = sideTag; + } + @Override + public void processElement(ProcessContext c) { + c.output(new TestDummy()); + c.sideOutput(sideTag, 1); + } + } + + @Test + public void testSideOutputUnknownCoder() { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline + .apply(Create.of(Arrays.asList(1, 2, 3))); + + // Expect a fail, but it should be a NoCoderException + final TupleTag mainTag = new TupleTag(); + final TupleTag sideTag = new TupleTag(); + input.apply(ParDo.of(new SideOutputDummyFn(sideTag)) + .withOutputTags(mainTag, TupleTagList.of(sideTag))); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("unable to infer a default Coder"); + pipeline.run(); + } + + @Test + public void testSideOutputUnregisteredExplicitCoder() { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline + .apply(Create.of(Arrays.asList(1, 2, 3))); + + final TupleTag mainTag = new TupleTag(); + final TupleTag sideTag = new TupleTag(); + PCollectionTuple outputTuple = input.apply(ParDo.of(new SideOutputDummyFn(sideTag)) + .withOutputTags(mainTag, TupleTagList.of(sideTag))); + + outputTuple.get(sideTag) + .setCoder(new TestDummyCoder()); + + pipeline.run(); + } + + @Test + public void testMainOutputUnregisteredExplicitCoder() { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline + .apply(Create.of(Arrays.asList(1, 2, 3))); + + final TupleTag mainTag = new TupleTag(); + final TupleTag sideTag = new TupleTag() {}; + PCollectionTuple outputTuple = input.apply(ParDo.of(new MainOutputDummyFn(sideTag)) + .withOutputTags(mainTag, TupleTagList.of(sideTag))); + + outputTuple.get(mainTag) + .setCoder(new TestDummyCoder()); + + pipeline.run(); + } + + @Test + public void testParDoOutputWithTimestamp() { + Pipeline p = TestPipeline.create(); + + PCollection input = + createInts(p, Arrays.asList(3, 42, 6)).setOrdered(true); + + PCollection output = + input + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.ZERO, Duration.ZERO))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + DataflowAssert.that(output).containsInAnyOrder( + "processing: 3, timestamp: 3", + "processing: 42, timestamp: 42", + "processing: 6, timestamp: 6"); + + p.run(); + } + + @Test + public void testParDoShiftTimestamp() { + Pipeline p = TestPipeline.create(); + + PCollection input = + createInts(p, Arrays.asList(3, 42, 6)).setOrdered(true); + + PCollection output = + input + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.millis(1000), + Duration.millis(-1000)))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + DataflowAssert.that(output).containsInAnyOrder( + "processing: 3, timestamp: -997", + "processing: 42, timestamp: -958", + "processing: 6, timestamp: -994"); + + p.run(); + } + + @Test + public void testParDoShiftTimestampInvalid() { + Pipeline p = TestPipeline.create(); + + createInts(p, Arrays.asList(3, 42, 6)).setOrdered(true) + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.millis(1000), + Duration.millis(-1001)))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + try { + p.run(); + fail("should have failed"); + } catch (RuntimeException exn) { + // expected + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PartitionTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PartitionTest.java new file mode 100644 index 000000000000..0d19f082ee07 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PartitionTest.java @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.createInts; +import static com.google.cloud.dataflow.sdk.transforms.Partition.PartitionFn; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for Partition + */ +@RunWith(JUnit4.class) +public class PartitionTest implements Serializable { + static class ModFn implements PartitionFn { + public int partitionFor(Integer elem, int numPartitions) { + return elem % numPartitions; + } + } + + static class IdentityFn implements PartitionFn { + public int partitionFor(Integer elem, int numPartitions) { + return elem; + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testEvenOddPartition() { + TestPipeline p = TestPipeline.create(); + + PCollection input = + createInts(p, Arrays.asList(591, 11789, 1257, 24578, 24799, 307)); + + PCollectionList outputs = input.apply(Partition.of(2, new ModFn())); + assertTrue(outputs.size() == 2); + DataflowAssert.that(outputs.get(0)).containsInAnyOrder(24578); + DataflowAssert.that(outputs.get(1)).containsInAnyOrder(591, 11789, 1257, + 24799, 307); + p.run(); + } + + @Test + public void testModPartition() { + TestPipeline p = TestPipeline.create(); + + PCollection input = + createInts(p, Arrays.asList(1, 2, 4, 5)); + + PCollectionList outputs = input.apply(Partition.of(3, new ModFn())); + assertTrue(outputs.size() == 3); + DataflowAssert.that(outputs.get(0)).containsInAnyOrder(); + DataflowAssert.that(outputs.get(1)).containsInAnyOrder(1, 4); + DataflowAssert.that(outputs.get(2)).containsInAnyOrder(2, 5); + p.run(); + } + + @Test + public void testOutOfBoundsPartitions() { + TestPipeline p = TestPipeline.create(); + + PCollection input = createInts(p, Arrays.asList(-1)); + + PCollectionList outputs = + input.apply(Partition.of(5, new IdentityFn())); + + try { + p.run(); + } catch (RuntimeException e) { + assertThat(e.toString(), containsString( + "Partition function returned out of bounds index: -1 not in [0..5)")); + } + } + + @Test + public void testZeroNumPartitions() { + TestPipeline p = TestPipeline.create(); + + PCollection input = createInts(p, Arrays.asList(591)); + + try { + PCollectionList outputs = + input.apply(Partition.of(0, new IdentityFn())); + fail("should have failed"); + } catch (IllegalArgumentException exn) { + assertThat(exn.toString(), containsString("numPartitions must be > 0")); + } + } + + @Test + public void testDroppedPartition() { + TestPipeline p = TestPipeline.create(); + + PCollection input = createInts(p, + Arrays.asList(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); + + // Compute the set of integers either 1 or 2 mod 3, the hard way. + PCollectionList outputs = + input.apply(Partition.of(3, new ModFn())); + + List> outputsList = new ArrayList<>(outputs.getAll()); + outputsList.remove(0); + outputs = PCollectionList.of(outputsList); + assertTrue(outputs.size() == 2); + + PCollection output = outputs.apply(Flatten.create()); + DataflowAssert.that(output).containsInAnyOrder(2, 4, 5, 7, 8, 10, 11); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RateLimitingTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RateLimitingTest.java new file mode 100644 index 000000000000..d6de05af6d3f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RateLimitingTest.java @@ -0,0 +1,225 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; + +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Tests for RateLimiter. + */ +@RunWith(JUnit4.class) +public class RateLimitingTest { + + /** + * Pass-thru function. + */ + private static class IdentityFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element()); + } + } + + /** + * Introduces a delay in processing, then passes thru elements. + */ + private static class DelayFn extends DoFn { + public static final long DELAY_MS = 250; + + @Override + public void processElement(ProcessContext c) { + try { + Thread.sleep(DELAY_MS); + } catch (InterruptedException e) { + e.printStackTrace(); + throw new RuntimeException("Interrupted"); + } + c.output(c.element()); + } + } + + /** + * Throws an exception after some number of calls. + */ + private static class ExceptionThrowingFn extends DoFn { + private final AtomicInteger numSuccesses; + private final AtomicInteger numProcessed = new AtomicInteger(); + private final AtomicInteger numFailures = new AtomicInteger(); + + private ExceptionThrowingFn(int numSuccesses) { + this.numSuccesses = new AtomicInteger(numSuccesses); + } + + @Override + public void processElement(ProcessContext c) { + numProcessed.incrementAndGet(); + if (numSuccesses.decrementAndGet() > 0) { + c.output(c.element()); + return; + } + + numFailures.incrementAndGet(); + throw new RuntimeException("Expected failure"); + } + } + + /** + * Measures concurrency of the processElement method. + * + *

Note: this only works when + * {@link DirectPipelineRunner#testSerializability} is disabled, otherwise + * the counters are not available after the run. + */ + private static class ConcurrencyMeasuringFn extends DoFn { + private int concurrentElements = 0; + private int maxConcurrency = 0; + + @Override + public void processElement(ProcessContext c) { + synchronized (this) { + concurrentElements++; + if (concurrentElements > maxConcurrency) { + maxConcurrency = concurrentElements; + } + } + + c.output(c.element()); + + synchronized (this) { + concurrentElements--; + } + } + } + + @Test + public void testRateLimitingMax() { + int n = 10; + double rate = 10.0; + long duration = runWithRate(n, rate, new IdentityFn()); + + long perElementPause = (long) (1000L / rate); + long minDuration = (n - 1) * perElementPause; + Assert.assertThat(duration, greaterThan(minDuration)); + } + + @Test(timeout = 5000L) + public void testExceptionHandling() { + ExceptionThrowingFn fn = new ExceptionThrowingFn<>(10); + try { + runWithRate(100, 0.0, fn); + Assert.fail("Expected exception to propagate"); + } catch (RuntimeException e) { + Assert.assertThat(e.getMessage(), containsString("Expected failure")); + } + + // Should have processed 10 elements, but stopped before processing all + // of them. + Assert.assertThat(fn.numProcessed.get(), + is(both(greaterThanOrEqualTo(10)) + .and(lessThan(100)))); + + // The first failure should prevent the scheduling of any more elements. + Assert.assertThat(fn.numFailures.get(), + is(both(greaterThanOrEqualTo(1)) + .and(lessThan(RateLimiting.DEFAULT_MAX_PARALLELISM)))); + } + + /** + * Test exception handling on the last element to be processed. + */ + @Test(timeout = 5000L) + public void testExceptionHandling2() { + ExceptionThrowingFn fn = new ExceptionThrowingFn<>(10); + try { + runWithRate(10, 0.0, fn); + Assert.fail("Expected exception to propagate"); + } catch (RuntimeException e) { + Assert.assertThat(e.getMessage(), containsString("Expected failure")); + } + + // Should have processed 10 elements, but stopped before processing all + // of them. + Assert.assertEquals(10, fn.numProcessed.get()); + Assert.assertEquals(1, fn.numFailures.get()); + } + + /** + * Provides more elements than can be scheduled at once, testing that the + * backlog limit is applied. + */ + @Test + public void testBacklogLimiter() { + long duration = runWithRate(2 * RateLimiting.DEFAULT_MAX_PARALLELISM, + -1.0 /* unlimited */, new DelayFn()); + + // Should take > 2x the delay interval, since no more than half the elements + // can be scheduled at once. + Assert.assertThat(duration, + greaterThan(2 * DelayFn.DELAY_MS)); + } + + private long runWithRate(int numElements, double rateLimit, + DoFn doFn) { + DirectPipeline p = DirectPipeline.createForTest(); + // Run with serializability testing disabled so that our tests can inspect + // the DoFns after the test. + p.getRunner().withSerializabilityTesting(false); + + ArrayList data = new ArrayList<>(numElements); + for (int i = 0; i < numElements; ++i) { + data.add(i); + } + + PCollection input = TestUtils.createInts(p, data); + + ConcurrencyMeasuringFn downstream = new ConcurrencyMeasuringFn<>(); + + PCollection output = input + .apply(RateLimiting.perWorker(doFn) + .withRateLimit(rateLimit)) + .apply(ParDo + .of(downstream)); + + long startTime = System.currentTimeMillis(); + + DirectPipelineRunner.EvaluationResults results = p.run(); + + // Downstream methods should not see parallel threads. + Assert.assertEquals(1, downstream.maxConcurrency); + + long endTime = System.currentTimeMillis(); + return endTime - startTime; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesTest.java new file mode 100644 index 000000000000..a44fa2d39103 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for RemovedDuplicates. + */ +@RunWith(JUnit4.class) +public class RemoveDuplicatesTest { + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testRemoveDuplicates() { + List strings = Arrays.asList( + "k1", + "k5", + "k5", + "k2", + "k1", + "k2", + "k3"); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(strings)) + .setCoder(StringUtf8Coder.of()); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + DataflowAssert.that(output) + .containsInAnyOrder("k1", "k5", "k2", "k3"); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testRemoveDuplicatesEmpty() { + List strings = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(strings)) + .setCoder(StringUtf8Coder.of()); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + DataflowAssert.that(output) + .containsInAnyOrder(); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SampleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SampleTest.java new file mode 100644 index 000000000000..7c51d096fe4e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SampleTest.java @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.api.client.util.Joiner; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Tests for Sample transform. + */ +@RunWith(JUnit4.class) +public class SampleTest { + static final Integer[] EMPTY = new Integer[] { }; + static final Integer[] DATA = new Integer[] {1, 2, 3, 4, 5}; + static final Integer[] REPEATED_DATA = new Integer[] {1, 1, 2, 2, 3, 3, 4, 4, 5, 5}; + + /** + * Verifies that the result of a Sample operation contains the expected number of elements, + * and that those elements are a subset of the items in expected. + */ + public static class VerifyCorrectSample + implements SerializableFunction, Void> { + private T[] expectedValues; + private int expectedSize; + + /** + * expectedSize is the number of elements that the Sample should contain. expected is the set + * of elements that the sample may contain. + */ + VerifyCorrectSample(int expectedSize, T... expected) { + this.expectedValues = expected; + this.expectedSize = expectedSize; + } + + @Override + public Void apply(Iterable in) { + List actual = new ArrayList<>(); + for (T elem : in) { + actual.add(elem); + } + + assertEquals(expectedSize, actual.size()); + + Collections.sort(actual); // We assume that @expected is already sorted. + int i = 0; // Index into @expected + for (T s : actual) { + boolean matchFound = false; + for (; i < expectedValues.length; i++) { + if (s.equals(expectedValues[i])) { + matchFound = true; + break; + } + } + assertTrue("Invalid sample: " + Joiner.on(',').join(actual), matchFound); + i++; // Don't match the same element again. + } + return null; + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSample() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA)) + .setCoder(BigEndianIntegerCoder.of()); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(3)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(3, DATA)); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSampleEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(EMPTY)) + .setCoder(BigEndianIntegerCoder.of()); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(3)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(0, EMPTY)); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSampleZero() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA)) + .setCoder(BigEndianIntegerCoder.of()); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(0)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(0, DATA)); + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSampleInsufficientElements() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA)) + .setCoder(BigEndianIntegerCoder.of()); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(10)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(5, DATA)); + p.run(); + } + + @Test(expected = IllegalArgumentException.class) + public void testSampleNegative() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA)) + .setCoder(BigEndianIntegerCoder.of()); + input.apply(Sample.fixedSizeGlobally(-1)); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSampleMultiplicity() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(REPEATED_DATA)) + .setCoder(BigEndianIntegerCoder.of()); + // At least one value must be selected with multiplicity. + PCollection> output = input.apply( + Sample.fixedSizeGlobally(6)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(6, REPEATED_DATA)); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SimpleStatsFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SimpleStatsFnsTest.java new file mode 100644 index 000000000000..909dcba9981f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SimpleStatsFnsTest.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests of Min, Max, Mean, and Sum. + */ +@RunWith(JUnit4.class) +public class SimpleStatsFnsTest { + static final double DOUBLE_COMPARISON_ACCURACY = 1e-7; + + private static class TestCase> { + final List data; + final N min; + final N max; + final N sum; + final Double mean; + + public TestCase(N min, N max, N sum, N... values) { + this.data = Arrays.asList(values); + this.min = min; + this.max = max; + this.sum = sum; + this.mean = + values.length == 0 ? 0.0 : sum.doubleValue() / values.length; + } + } + + static final List> DOUBLE_CASES = Arrays.asList( + new TestCase<>(-312.31, 6312.31, 11629.13, + -312.31, 29.13, 112.158, 6312.31, -312.158, -312.158, 112.158, + -312.31, 6312.31, 0.0), + new TestCase<>(3.14, 3.14, 3.14, 3.14), + new TestCase<>(Double.MAX_VALUE, Double.MIN_NORMAL, 0.0)); + + static final List> LONG_CASES = Arrays.asList( + new TestCase<>(-50000000000000000L, + 70000000000000000L, + 60000033123213121L, + 0L, 1L, 10000000000000000L, -50000000000000000L, + 70000000000000000L, 0L, 10000000000000000L, -1L, + -50000000000000000L, 70000000000000000L, 33123213121L), + new TestCase<>(3L, 3L, 3L, 3L), + new TestCase<>(Long.MAX_VALUE, Long.MIN_VALUE, 0L)); + + static final List> INTEGER_CASES = Arrays.asList( + new TestCase<>(-3, 6, 22, + 1, -3, 2, 6, 3, 4, -3, 5, 6, 1), + new TestCase<>(3, 3, 3, 3), + new TestCase<>(Integer.MAX_VALUE, Integer.MIN_VALUE, 0)); + + @Test + public void testDoubleStats() { + for (TestCase t : DOUBLE_CASES) { + assertEquals(t.sum, new Sum.SumDoubleFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + assertEquals(t.min, new Min.MinDoubleFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + assertEquals(t.max, new Max.MaxDoubleFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + assertEquals(t.mean, new Mean.MeanFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + } + } + + @Test + public void testIntegerStats() { + for (TestCase t : INTEGER_CASES) { + assertEquals(t.sum, new Sum.SumIntegerFn().apply(t.data)); + assertEquals(t.min, new Min.MinIntegerFn().apply(t.data)); + assertEquals(t.max, new Max.MaxIntegerFn().apply(t.data)); + assertEquals(t.mean, new Mean.MeanFn().apply(t.data)); + } + } + + @Test + public void testLongStats() { + for (TestCase t : LONG_CASES) { + assertEquals(t.sum, new Sum.SumLongFn().apply(t.data)); + assertEquals(t.min, new Min.MinLongFn().apply(t.data)); + assertEquals(t.max, new Max.MaxLongFn().apply(t.data)); + assertEquals(t.mean, new Mean.MeanFn().apply(t.data)); + } + } + + @Test + public void testMeanCountSumSerializable() { + Pipeline p = TestPipeline.create(); + + PCollection> input = p + .apply(Create.of(KV.of(1L, 1.5), KV.of(2L, 7.3))) + .setCoder(KvCoder.of(VarLongCoder.of(), DoubleCoder.of())); + + PCollection> meanPerKey = + input.apply(Mean.perKey()); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/TopTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/TopTest.java new file mode 100644 index 000000000000..63625a7f5f2b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/TopTest.java @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.EvaluationResults; +import com.google.cloud.dataflow.sdk.runners.RecordingPipelineVisitor; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +/** Tests for Top */ +@RunWith(JUnit4.class) +public class TopTest { + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @SuppressWarnings("unchecked") + static final String[] COLLECTION = new String[] { + "a", "bb", "c", "c", "z" + }; + + @SuppressWarnings("unchecked") + static final String[] EMPTY_COLLECTION = new String[] { + }; + + @SuppressWarnings("unchecked") + static final KV[] TABLE = new KV[] { + KV.of("a", 1), + KV.of("a", 2), + KV.of("a", 3), + KV.of("b", 1), + KV.of("b", 10), + KV.of("b", 10), + KV.of("b", 100), + }; + + @SuppressWarnings("unchecked") + static final KV[] EMPTY_TABLE = new KV[] { + }; + + public PCollection> createInputTable(Pipeline p) { + return p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + } + + public PCollection> createEmptyInputTable(Pipeline p) { + return p.apply(Create.of(Arrays.asList(EMPTY_TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + } + + @Test + @SuppressWarnings("unchecked") + public void testTop() { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))) + .setCoder(StringUtf8Coder.of()); + + PCollection> top1 = input.apply(Top.of(1, new OrderByLength())); + PCollection> top2 = input.apply(Top.largest(2)); + PCollection> top3 = input.apply(Top.smallest(3)); + + PCollection>> largestPerKey = createInputTable(p) + .apply(Top.largestPerKey(2)); + PCollection>> smallestPerKey = createInputTable(p) + .apply(Top.smallestPerKey(2)); + + EvaluationResults results = p.run(); + + assertThat(results.getPCollection(top1).get(0), contains("bb")); + assertThat(results.getPCollection(top2).get(0), contains("z", "c")); + assertThat(results.getPCollection(top3).get(0), contains("a", "bb", "c")); + assertThat(results.getPCollection(largestPerKey), containsInAnyOrder( + KV.of("a", Arrays.asList(3, 2)), + KV.of("b", Arrays.asList(100, 10)))); + assertThat(results.getPCollection(smallestPerKey), containsInAnyOrder( + KV.of("a", Arrays.asList(1, 2)), + KV.of("b", Arrays.asList(1, 10)))); + } + + @Test + @SuppressWarnings("unchecked") + public void testTopEmpty() { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + p.apply(Create.of(Arrays.asList(EMPTY_COLLECTION))) + .setCoder(StringUtf8Coder.of()); + + PCollection> top1 = input.apply(Top.of(1, new OrderByLength())); + PCollection> top2 = input.apply(Top.largest(2)); + PCollection> top3 = input.apply(Top.smallest(3)); + + PCollection>> largestPerKey = createEmptyInputTable(p) + .apply(Top.largestPerKey(2)); + PCollection>> smallestPerKey = createEmptyInputTable(p) + .apply(Top.smallestPerKey(2)); + + EvaluationResults results = p.run(); + + assertThat(results.getPCollection(top1).get(0), containsInAnyOrder()); + assertThat(results.getPCollection(top2).get(0), containsInAnyOrder()); + assertThat(results.getPCollection(top3).get(0), containsInAnyOrder()); + assertThat(results.getPCollection(largestPerKey), containsInAnyOrder()); + assertThat(results.getPCollection(smallestPerKey), containsInAnyOrder()); + } + + @Test + @SuppressWarnings("unchecked") + public void testTopZero() { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))) + .setCoder(StringUtf8Coder.of()); + + PCollection> top1 = input.apply(Top.of(0, new OrderByLength())); + PCollection> top2 = input.apply(Top.largest(0)); + PCollection> top3 = input.apply(Top.smallest(0)); + + PCollection>> largestPerKey = createInputTable(p) + .apply(Top.largestPerKey(0)); + + PCollection>> smallestPerKey = createInputTable(p) + .apply(Top.smallestPerKey(0)); + + EvaluationResults results = p.run(); + + assertThat(results.getPCollection(top1).get(0), containsInAnyOrder()); + assertThat(results.getPCollection(top2).get(0), containsInAnyOrder()); + assertThat(results.getPCollection(top3).get(0), containsInAnyOrder()); + assertThat(results.getPCollection(largestPerKey), containsInAnyOrder( + KV.of("a", Arrays.asList()), + KV.of("b", Arrays.asList()))); + assertThat(results.getPCollection(smallestPerKey), containsInAnyOrder( + KV.of("a", Arrays.asList()), + KV.of("b", Arrays.asList()))); + } + + // This is a purely compile-time test. If the code compiles, then it worked. + @Test + public void testPerKeySerializabilityRequirement() { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))) + .setCoder(StringUtf8Coder.of()); + + PCollection>> top1 = createInputTable(p) + .apply(Top.perKey(1, + new IntegerComparator())); + + PCollection>> top2 = createInputTable(p) + .apply(Top.perKey(1, + new IntegerComparator2())); + } + + @Test + public void testCountConstraint() { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))) + .setCoder(StringUtf8Coder.of()); + + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString(">= 0")); + + input.apply(Top.of(-1, new OrderByLength())); + } + + @Test + public void testTransformName() { + DirectPipeline p = DirectPipeline.createForTest(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))) + .setCoder(StringUtf8Coder.of()); + + PTransform, PCollection>> top = Top + .of(10, new OrderByLength()); + input.apply(top); + + p.traverseTopologically(new RecordingPipelineVisitor()); + // Check that the transform is named "Top" rather than "Combine". + assertThat(p.getFullName(top), Matchers.startsWith("Top")); + } + + static class OrderByLength implements Comparator, Serializable { + @Override + public int compare(String a, String b) { + if (a.length() != b.length()) { + return a.length() - b.length(); + } else { + return a.compareTo(b); + } + } + } + + static class IntegerComparator implements Comparator, Serializable { + @Override + public int compare(Integer o1, Integer o2) { + return o1.compareTo(o2); + } + } + + static class IntegerComparator2 implements SerializableComparator { + @Override + public int compare(Integer o1, Integer o2) { + return o1.compareTo(o2); + } + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ValuesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ValuesTest.java new file mode 100644 index 000000000000..497d8fc8406e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ValuesTest.java @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for Values transform. + */ +@RunWith(JUnit4.class) +public class ValuesTest { + static final KV[] TABLE = new KV[] { + KV.of("one", 1), + KV.of("two", 2), + KV.of("three", 3), + KV.of("four", 4), + KV.of("dup", 4) + }; + + static final KV[] EMPTY_TABLE = new KV[] { + }; + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testValues() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection output = input.apply(Values.create()); + + DataflowAssert.that(output) + .containsInAnyOrder(1, 2, 3, 4, 4); + + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testValuesOrdered() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + input.setOrdered(true); + PCollection output = + input.apply(Values.create()).setOrdered(true); + + DataflowAssert.that(output) + .containsInOrder(1, 2, 3, 4, 4); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testValuesEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(EMPTY_TABLE))).setCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())); + + PCollection output = input.apply(Values.create()); + + DataflowAssert.that(output) + .containsInAnyOrder(); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ViewTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ViewTest.java new file mode 100644 index 000000000000..3a7c8187d923 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ViewTest.java @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.CoreMatchers.isA; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.NoSuchElementException; + +/** + * Tests for {@link View}. See also {@link ParDoTest} which + * provides additional coverage since views can only be + * observed via {@link ParDo}. + */ +@RunWith(JUnit4.class) +public class ViewTest implements Serializable { + // This test is Serializable, just so that it's easy to have + // anonymous inner classes inside the non-static test methods. + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testSingletonSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView view = pipeline + .apply(Create.of(47)) + .apply(View.asSingleton()); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + DataflowAssert.that(output) + .containsInAnyOrder(47, 47, 47); + + pipeline.run(); + } + + @Test + public void testEmptySingletonSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView view = pipeline + .apply(Create.of()) + .setCoder(VarIntCoder.of()) + .apply(View.asSingleton()); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + thrown.expect(RuntimeException.class); + thrown.expectCause(isA(NoSuchElementException.class)); + thrown.expectMessage("Empty"); + thrown.expectMessage("PCollection"); + thrown.expectMessage("singleton"); + + pipeline.run(); + } + + @Test + public void testNonSingletonSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView view = pipeline + .apply(Create.of(1, 2, 3)) + .apply(View.asSingleton()); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + thrown.expect(RuntimeException.class); + thrown.expectCause(isA(IllegalArgumentException.class)); + thrown.expectMessage("PCollection"); + thrown.expectMessage("more than one"); + thrown.expectMessage("singleton"); + + pipeline.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testIterableSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView, ?> view = pipeline + .apply(Create.of(11, 13, 17, 23)) + .apply(View.asIterable()); + + PCollection output = pipeline + .apply(Create.of(29, 31)) + .apply(ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (Integer i : c.sideInput(view)) { + c.output(i); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + 11, 13, 17, 23, + 11, 13, 17, 23); + + pipeline.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithKeysTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithKeysTest.java new file mode 100644 index 000000000000..3e4e359022c9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithKeysTest.java @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for ExtractKeys transform. + */ +@RunWith(JUnit4.class) +public class WithKeysTest { + static final String[] COLLECTION = new String[] { + "a", + "aa", + "b", + "bb", + "bbb" + }; + + static final List> WITH_KEYS = Arrays.asList( + KV.of(1, "a"), + KV.of(2, "aa"), + KV.of(1, "b"), + KV.of(2, "bb"), + KV.of(3, "bbb") + ); + + static final List> WITH_CONST_KEYS = Arrays.asList( + KV.of(100, "a"), + KV.of(100, "aa"), + KV.of(100, "b"), + KV.of(100, "bb"), + KV.of(100, "bbb") + ); + + @Test + public void testExtractKeys() { + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))).setCoder( + StringUtf8Coder.of()); + + PCollection> output = input.apply(WithKeys.of( + new LengthAsKey())); + DataflowAssert.that(output) + .containsInAnyOrder(WITH_KEYS); + + p.run(); + } + + // TODO: setOrdered(true) isn't supported yet by the Dataflow service. + @Test + public void testExtractKeysOrdered() { + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))).setCoder( + StringUtf8Coder.of()); + + input.setOrdered(true); + PCollection> output = input.apply(WithKeys.of( + new LengthAsKey())).setOrdered(true); + DataflowAssert.that(output) + .containsInAnyOrder(WITH_KEYS); + + p.run(); + } + + @Test + public void testConstantKeys() { + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION))).setCoder( + StringUtf8Coder.of()); + + PCollection> output = + input.apply(WithKeys.of(100)); + DataflowAssert.that(output) + .containsInAnyOrder(WITH_CONST_KEYS); + + p.run(); + } + + /** + * Key a value by its length. + */ + public static class LengthAsKey + implements SerializableFunction { + @Override + public Integer apply(String value) { + return value.length(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultCoderTest.java new file mode 100644 index 000000000000..afb8a998798e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultCoderTest.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult.CoGbkResultCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests the CoGbkResult.CoGbkResultCoder. + */ +@RunWith(JUnit4.class) +public class CoGbkResultCoderTest { + + @Test + public void testSerializationDeserialization() { + CoGbkResultSchema schema = + new CoGbkResultSchema(TupleTagList.of(new TupleTag()).and( + new TupleTag())); + UnionCoder unionCoder = + UnionCoder.of(Arrays.>asList(StringUtf8Coder.of(), + DoubleCoder.of())); + CoGbkResultCoder newCoder = CoGbkResultCoder.of(schema, unionCoder); + CloudObject encoding = newCoder.asCloudObject(); + Coder decodedCoder = Serializer.deserialize(encoding, Coder.class); + assertEquals(newCoder, decodedCoder); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKeyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKeyTest.java new file mode 100644 index 000000000000..016ba15d5ae4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKeyTest.java @@ -0,0 +1,348 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.EvaluationResults; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.hamcrest.Matcher; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +/** + * Tests for CoGroupByKeyTest. Implements Serializable for anonymous DoFns. + */ +@RunWith(JUnit4.class) +public class CoGroupByKeyTest implements Serializable { + + /** + * Converts the given list into a PCollection belonging to the provided + * Pipeline in such a way that coder inference needs to be performed. + */ + private PCollection> createInput( + Pipeline p, List> list) { + return p + .apply(Create.of(list)) + // Create doesn't infer coders for parameterized types. + .setCoder( + KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of())) + // Do a dummy transform so consumers must deal with coder inference. + .apply(ParDo.of(new DoFn, + KV>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element()); + } + })); + } + + /** + * Returns a PCollection> containing the + * results of the CoGbk over 3 PCollection>, each of + * which correlates a customer id to purchases, addresses, or names, + * respectively. + */ + private PCollection> buildPurchasesCoGbk( + Pipeline p, + TupleTag purchasesTag, + TupleTag addressesTag, + TupleTag namesTag) { + List> idToPurchases = + Arrays.asList( + KV.of(2, "Boat"), + KV.of(1, "Shoes"), + KV.of(3, "Car"), + KV.of(1, "Book"), + KV.of(10, "Pens"), + KV.of(8, "House"), + KV.of(4, "Suit"), + KV.of(11, "House"), + KV.of(14, "Shoes"), + KV.of(2, "Suit"), + KV.of(8, "Suit Case"), + KV.of(3, "House")); + + List> idToAddress = + Arrays.asList( + KV.of(2, "53 S. 3rd"), + KV.of(10, "383 Jackson Street"), + KV.of(20, "3 W. Arizona"), + KV.of(3, "29 School Rd"), + KV.of(8, "6 Watling Rd")); + + List> idToName = + Arrays.asList( + KV.of(1, "John Smith"), + KV.of(2, "Sally James"), + KV.of(8, "Jeffery Spalding"), + KV.of(20, "Joan Lichtfield")); + + PCollection> purchasesTable = + createInput(p, idToPurchases); + + PCollection> addressTable = + createInput(p, idToAddress); + + PCollection> nameTable = + createInput(p, idToName); + + PCollection> coGbkResults = + KeyedPCollectionTuple.of(namesTag, nameTable) + .and(addressesTag, addressTable) + .and(purchasesTag, purchasesTable) + .apply(CoGroupByKey.create()); + return coGbkResults; + } + + @Test + public void testCoGroupByKey() { + TupleTag namesTag = new TupleTag<>(); + TupleTag addressesTag = new TupleTag<>(); + TupleTag purchasesTag = new TupleTag<>(); + + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection> coGbkResults = + buildPurchasesCoGbk(p, purchasesTag, addressesTag, namesTag); + + EvaluationResults results = p.run(); + + List> finalResult = + results.getPCollection(coGbkResults); + + HashMap>> namesMatchers = + new HashMap>>() { + { + put(1, containsInAnyOrder("John Smith")); + put(2, containsInAnyOrder("Sally James")); + put(8, containsInAnyOrder("Jeffery Spalding")); + put(20, containsInAnyOrder("Joan Lichtfield")); + } + }; + + HashMap>> addressesMatchers = + new HashMap>>() { + { + put(2, containsInAnyOrder("53 S. 3rd")); + put(3, containsInAnyOrder("29 School Rd")); + put(8, containsInAnyOrder("6 Watling Rd")); + put(10, containsInAnyOrder("383 Jackson Street")); + put(20, containsInAnyOrder("3 W. Arizona")); + } + }; + + HashMap>> purchasesMatchers = + new HashMap>>() { + { + put(1, containsInAnyOrder("Shoes", "Book")); + put(2, containsInAnyOrder("Suit", "Boat")); + put(3, containsInAnyOrder("Car", "House")); + put(4, containsInAnyOrder("Suit")); + put(8, containsInAnyOrder("House", "Suit Case")); + put(10, containsInAnyOrder("Pens")); + put(11, containsInAnyOrder("House")); + put(14, containsInAnyOrder("Shoes")); + } + }; + + // TODO: Figure out a way to do a hamcrest matcher for CoGbkResults. + for (KV result : finalResult) { + int key = result.getKey(); + CoGbkResult row = result.getValue(); + checkValuesMatch(key, namesMatchers, row, namesTag); + checkValuesMatch(key, addressesMatchers, row, addressesTag); + checkValuesMatch(key, purchasesMatchers, row, purchasesTag); + + } + + } + + /** + * Checks that the values for the given tag in the given row matches the + * expected values for the given key in the given matchers map. + */ + private void checkValuesMatch( + K key, + HashMap>> matchers, + CoGbkResult row, + TupleTag tag) { + Iterable taggedValues = row.getAll(tag); + if (taggedValues.iterator().hasNext()) { + assertThat(taggedValues, matchers.get(key)); + } else { + assertNull(matchers.get(key)); + } + } + + /** + * A DoFn used in testCoGroupByKeyHandleResults(), to test processing the + * results of a CoGroupByKey. + */ + private static class CorrelatePurchaseCountForAddressesWithoutNamesFn extends + DoFn, KV> { + private final TupleTag purchasesTag; + + private final TupleTag addressesTag; + + private final TupleTag namesTag; + + private CorrelatePurchaseCountForAddressesWithoutNamesFn( + TupleTag purchasesTag, + TupleTag addressesTag, + TupleTag namesTag) { + this.purchasesTag = purchasesTag; + this.addressesTag = addressesTag; + this.namesTag = namesTag; + } + + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + CoGbkResult row = e.getValue(); + // Don't actually care about the id. + Iterable names = row.getAll(namesTag); + if (names.iterator().hasNext()) { + // Nothing to do. There was a name. + return; + } + Iterable addresses = row.getAll(addressesTag); + if (!addresses.iterator().hasNext()) { + // Nothing to do, there was no address. + return; + } + // Buffer the addresses so we can accredit all of them with + // corresponding purchases. All addresses are for the same id, so + // if there are multiple, we apply the same purchase count to all. + ArrayList addressList = new ArrayList(); + for (String address : addresses) { + addressList.add(address); + } + + Iterable purchases = row.getAll(purchasesTag); + + int purchaseCount = 0; + for (String purchase : purchases) { + purchaseCount++; + } + + for (String address : addressList) { + c.output(KV.of(address, purchaseCount)); + } + } + } + + /** + * Tests that the consuming DoFn + * (CorrelatePurchaseCountForAddressesWithoutNamesFn) performs as expected. + */ + @SuppressWarnings("unchecked") + @Test + public void testConsumingDoFn() { + TupleTag purchasesTag = new TupleTag<>(); + TupleTag addressesTag = new TupleTag<>(); + TupleTag namesTag = new TupleTag<>(); + + // result1 should get filtered out because it has a name. + CoGbkResult result1 = CoGbkResult + .of(purchasesTag, Arrays.asList("3a", "3b")) + .and(addressesTag, Arrays.asList("2a", "2b")) + .and(namesTag, Arrays.asList("1a")); + // result 2 should be counted because it has an address and purchases. + CoGbkResult result2 = CoGbkResult + .of(purchasesTag, Arrays.asList("5a", "5b")) + .and(addressesTag, Arrays.asList("4a")) + .and(namesTag, new ArrayList()); + // result 3 should not be counted because it has no addresses. + CoGbkResult result3 = CoGbkResult + .of(purchasesTag, Arrays.asList("7a", "7b")) + .and(addressesTag, new ArrayList()) + .and(namesTag, new ArrayList()); + // result 4 should be counted as 0, because it has no purchases. + CoGbkResult result4 = CoGbkResult + .of(purchasesTag, new ArrayList()) + .and(addressesTag, Arrays.asList("8a")) + .and(namesTag, new ArrayList()); + + List> results = + DoFnTester.of( + new CorrelatePurchaseCountForAddressesWithoutNamesFn( + purchasesTag, + addressesTag, + namesTag)) + .processBatch( + KV.of(1, result1), + KV.of(2, result2), + KV.of(3, result3), + KV.of(4, result4)); + assertThat(results, containsInAnyOrder(KV.of("4a", 2), KV.of("8a", 0))); + } + + /** + * Tests the pipeline end-to-end. Builds the purchases CoGroupByKey, and + * applies CorrelatePurchaseCountForAddressesWithoutNamesFn to the results. + */ + @SuppressWarnings("unchecked") + @Test + public void testCoGroupByKeyHandleResults() { + TupleTag namesTag = new TupleTag<>(); + TupleTag addressesTag = new TupleTag<>(); + TupleTag purchasesTag = new TupleTag<>(); + + Pipeline p = TestPipeline.create(); + + PCollection> coGbkResults = + buildPurchasesCoGbk(p, purchasesTag, addressesTag, namesTag); + + // Do some simple processing on the result of the CoGroupByKey. Count the + // purchases for each address on record that has no associated name. + PCollection> + purchaseCountByKnownAddressesWithoutKnownNames = + coGbkResults.apply(ParDo.of( + new CorrelatePurchaseCountForAddressesWithoutNamesFn( + purchasesTag, addressesTag, namesTag))); + + DataflowAssert.that(purchaseCountByKnownAddressesWithoutKnownNames) + .containsInAnyOrder( + KV.of("29 School Rd", 2), + KV.of("383 Jackson Street", 1)); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoderTest.java new file mode 100644 index 000000000000..24e6dde65c4f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoderTest.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.Serializer; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests the UnionCoder. + */ +@RunWith(JUnit4.class) +public class UnionCoderTest { + + @Test + public void testSerializationDeserialization() { + UnionCoder newCoder = + UnionCoder.of(Arrays.>asList(StringUtf8Coder.of(), + DoubleCoder.of())); + CloudObject encoding = newCoder.asCloudObject(); + Coder decodedCoder = Serializer.deserialize(encoding, Coder.class); + assertEquals(newCoder, decodedCoder); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindowsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindowsTest.java new file mode 100644 index 000000000000..36028e493a75 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindowsTest.java @@ -0,0 +1,260 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.runWindowingFn; +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.set; +import static org.junit.Assert.assertEquals; + +import org.joda.time.DateTime; +import org.joda.time.DateTimeConstants; +import org.joda.time.DateTimeZone; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Tests for CalendarWindows WindowingFn. + */ +@RunWith(JUnit4.class) +public class CalendarWindowsTest { + + private static Instant makeTimestamp(int year, int month, int day, int hours, int minutes) { + return new DateTime(year, month, day, hours, minutes, DateTimeZone.UTC).toInstant(); + } + + @Test + public void testDays() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 1, 1, 0, 0).getMillis(), + makeTimestamp(2014, 1, 1, 23, 59).getMillis(), + + makeTimestamp(2014, 1, 2, 0, 0).getMillis(), + makeTimestamp(2014, 1, 2, 5, 5).getMillis(), + + makeTimestamp(2015, 1, 1, 0, 0).getMillis(), + makeTimestamp(2015, 1, 1, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 1, 0, 0), + makeTimestamp(2014, 1, 2, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 2, 0, 0), + makeTimestamp(2014, 1, 3, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2015, 1, 1, 0, 0), + makeTimestamp(2015, 1, 2, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, runWindowingFn(CalendarWindows.days(1), timestamps)); + } + + @Test + public void testWeeks() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 1, 1, 0, 0).getMillis(), + makeTimestamp(2014, 1, 5, 5, 5).getMillis(), + + makeTimestamp(2014, 1, 8, 0, 0).getMillis(), + makeTimestamp(2014, 1, 12, 5, 5).getMillis(), + + makeTimestamp(2015, 1, 1, 0, 0).getMillis(), + makeTimestamp(2015, 1, 6, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 1, 0, 0), + makeTimestamp(2014, 1, 8, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 8, 0, 0), + makeTimestamp(2014, 1, 15, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 12, 31, 0, 0), + makeTimestamp(2015, 1, 7, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, + runWindowingFn(CalendarWindows.weeks(1, DateTimeConstants.WEDNESDAY), timestamps)); + } + + @Test + public void testMonths() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 1, 1, 0, 0).getMillis(), + makeTimestamp(2014, 1, 31, 5, 5).getMillis(), + + makeTimestamp(2014, 2, 1, 0, 0).getMillis(), + makeTimestamp(2014, 2, 15, 5, 5).getMillis(), + + makeTimestamp(2015, 1, 1, 0, 0).getMillis(), + makeTimestamp(2015, 1, 31, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 1, 0, 0), + makeTimestamp(2014, 2, 1, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 2, 1, 0, 0), + makeTimestamp(2014, 3, 1, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2015, 1, 1, 0, 0), + makeTimestamp(2015, 2, 1, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, + runWindowingFn(CalendarWindows.months(1), timestamps)); + } + + @Test + public void testMultiMonths() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 3, 5, 0, 0).getMillis(), + makeTimestamp(2014, 10, 4, 23, 59).getMillis(), + + makeTimestamp(2014, 10, 5, 0, 0).getMillis(), + makeTimestamp(2015, 3, 1, 0, 0).getMillis(), + + makeTimestamp(2016, 1, 5, 0, 0).getMillis(), + makeTimestamp(2016, 1, 31, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 3, 5, 0, 0), + makeTimestamp(2014, 10, 5, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 10, 5, 0, 0), + makeTimestamp(2015, 5, 5, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2015, 12, 5, 0, 0), + makeTimestamp(2016, 7, 5, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, runWindowingFn( + CalendarWindows.months(7).withStartingMonth(2014, 3).beginningOnDay(5), timestamps)); + } + + @Test + public void testYears() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2000, 5, 5, 0, 0).getMillis(), + makeTimestamp(2010, 5, 4, 23, 59).getMillis(), + + makeTimestamp(2010, 5, 5, 0, 0).getMillis(), + makeTimestamp(2015, 3, 1, 0, 0).getMillis(), + + makeTimestamp(2052, 1, 5, 0, 0).getMillis(), + makeTimestamp(2060, 5, 4, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2000, 5, 5, 0, 0), + makeTimestamp(2010, 5, 5, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2010, 5, 5, 0, 0), + makeTimestamp(2020, 5, 5, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2050, 5, 5, 0, 0), + makeTimestamp(2060, 5, 5, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, runWindowingFn( + CalendarWindows.years(10).withStartingYear(2000).beginningOnDay(5, 5), timestamps)); + } + + @Test + public void testTimeZone() throws Exception { + Map> expected = new HashMap<>(); + + DateTimeZone timeZone = DateTimeZone.forID("America/Los_Angeles"); + + final List timestamps = Arrays.asList( + new DateTime(2014, 1, 1, 0, 0, timeZone).getMillis(), + new DateTime(2014, 1, 1, 23, 59, timeZone).getMillis(), + + new DateTime(2014, 1, 2, 8, 0, DateTimeZone.UTC).getMillis(), + new DateTime(2014, 1, 3, 7, 59, DateTimeZone.UTC).getMillis()); + + expected.put( + new IntervalWindow( + new DateTime(2014, 1, 1, 0, 0, timeZone).toInstant(), + new DateTime(2014, 1, 2, 0, 0, timeZone).toInstant()), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + new DateTime(2014, 1, 2, 0, 0, timeZone).toInstant(), + new DateTime(2014, 1, 3, 0, 0, timeZone).toInstant()), + set(timestamps.get(2), timestamps.get(3))); + + assertEquals(expected, runWindowingFn( + CalendarWindows.days(1).withTimeZone(timeZone), + timestamps)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindowsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindowsTest.java new file mode 100644 index 000000000000..0a68e72348f7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindowsTest.java @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.runWindowingFn; +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.set; +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Tests for FixedWindows WindowingFn. + */ +@RunWith(JUnit4.class) +public class FixedWindowsTest { + + @Test + public void testSimpleFixedWindow() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(1, 2, 5, 9)); + expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10, 11)); + expected.put(new IntervalWindow(new Instant(100), new Instant(110)), set(100)); + assertEquals( + expected, + runWindowingFn( + FixedWindows.of(new Duration(10)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L, 100L))); + } + + @Test + public void testFixedOffsetWindow() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5), new Instant(5)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(5), new Instant(15)), set(5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(95), new Instant(105)), set(100)); + assertEquals( + expected, + runWindowingFn( + FixedWindows.of(new Duration(10)).withOffset(new Duration(5)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L, 100L))); + } + + @Test + public void testTimeUnit() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5000), new Instant(5000)), set(1, 2, 1000)); + expected.put(new IntervalWindow(new Instant(5000), new Instant(15000)), set(5000, 5001, 10000)); + assertEquals( + expected, + runWindowingFn( + FixedWindows.of(Duration.standardSeconds(10)).withOffset(Duration.standardSeconds(5)), + Arrays.asList(1L, 2L, 1000L, 5000L, 5001L, 10000L))); + } + + void checkConstructionFailure(int size, int offset) { + try { + FixedWindows.of(Duration.standardSeconds(size)).withOffset(Duration.standardSeconds(offset)); + fail("should have failed"); + } catch (IllegalArgumentException e) { + assertThat(e.toString(), + containsString("FixedWindows WindowingStrategies must have 0 <= offset < size")); + } + } + + @Test + public void testInvalidInput() throws Exception { + checkConstructionFailure(-1, 0); + checkConstructionFailure(1, 2); + checkConstructionFailure(1, -1); + } + + @Test + public void testEquality() { + assertTrue(FixedWindows.of(new Duration(10)).isCompatible(FixedWindows.of(new Duration(10)))); + assertTrue( + FixedWindows.of(new Duration(10)).isCompatible( + FixedWindows.of(new Duration(10)))); + assertTrue( + FixedWindows.of(new Duration(10)).isCompatible( + FixedWindows.of(new Duration(10)))); + + assertFalse(FixedWindows.of(new Duration(10)).isCompatible(FixedWindows.of(new Duration(20)))); + assertFalse(FixedWindows.of(new Duration(10)).isCompatible( + FixedWindows.of(new Duration(20)))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SessionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SessionsTest.java new file mode 100644 index 000000000000..ccb1ddecc496 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SessionsTest.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.runWindowingFn; +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.set; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Tests for Sessions WindowingFn. + */ +@RunWith(JUnit4.class) +public class SessionsTest { + + @Test + public void testSimple() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(0)); + expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10)); + expected.put(new IntervalWindow(new Instant(101), new Instant(111)), set(101)); + assertEquals( + expected, + runWindowingFn( + Sessions.withGapDuration(new Duration(10)), + Arrays.asList(0L, 10L, 101L))); + } + + @Test + public void testConsecutive() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(1), new Instant(19)), set(1, 2, 5, 9)); + expected.put(new IntervalWindow(new Instant(100), new Instant(111)), set(100, 101)); + assertEquals( + expected, + runWindowingFn( + Sessions.withGapDuration(new Duration(10)), + Arrays.asList(1L, 2L, 5L, 9L, 100L, 101L))); + } + + @Test + public void testMerging() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(1), new Instant(40)), set(1, 10, 15, 22, 30)); + expected.put(new IntervalWindow(new Instant(95), new Instant(111)), set(95, 100, 101)); + assertEquals( + expected, + runWindowingFn( + Sessions.withGapDuration(new Duration(10)), + Arrays.asList(1L, 15L, 30L, 100L, 101L, 95L, 22L, 10L))); + } + + @Test + public void testTimeUnit() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(1), new Instant(2000)), set(1, 2, 1000)); + expected.put(new IntervalWindow(new Instant(5000), new Instant(6001)), set(5000, 5001)); + expected.put(new IntervalWindow(new Instant(10000), new Instant(11000)), set(10000)); + assertEquals( + expected, + runWindowingFn( + Sessions.withGapDuration(Duration.standardSeconds(1)), + Arrays.asList(1L, 2L, 1000L, 5000L, 5001L, 10000L))); + } + + @Test + public void testEquality() { + assertTrue( + Sessions.withGapDuration(new Duration(10)).isCompatible( + Sessions.withGapDuration(new Duration(10)))); + assertTrue( + Sessions.withGapDuration(new Duration(10)).isCompatible( + Sessions.withGapDuration(new Duration(20)))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindowsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindowsTest.java new file mode 100644 index 000000000000..f187cb429940 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindowsTest.java @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.runWindowingFn; +import static com.google.cloud.dataflow.sdk.testing.WindowingFnTestUtils.set; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Tests for the SlidingWindows WindowingFn. + */ +@RunWith(JUnit4.class) +public class SlidingWindowsTest { + + @Test + public void testSimple() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5), new Instant(5)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(1, 2, 5, 9)); + expected.put(new IntervalWindow(new Instant(5), new Instant(15)), set(5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10, 11)); + assertEquals( + expected, + runWindowingFn( + SlidingWindows.of(new Duration(10)).every(new Duration(5)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + } + + @Test + public void testSlightlyOverlapping() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5), new Instant(2)), set(1)); + expected.put(new IntervalWindow(new Instant(0), new Instant(7)), set(1, 2, 5)); + expected.put(new IntervalWindow(new Instant(5), new Instant(12)), set(5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(10), new Instant(17)), set(10, 11)); + assertEquals( + expected, + runWindowingFn( + SlidingWindows.of(new Duration(7)).every(new Duration(5)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + } + + @Test + public void testElidings() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(3)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(10), new Instant(13)), set(10, 11)); + expected.put(new IntervalWindow(new Instant(100), new Instant(103)), set(100)); + assertEquals( + expected, + runWindowingFn( + // Only look at the first 3 millisecs of every 10-millisec interval. + SlidingWindows.of(new Duration(3)).every(new Duration(10)), + Arrays.asList(1L, 2L, 3L, 5L, 9L, 10L, 11L, 100L))); + } + + @Test + public void testOffset() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-8), new Instant(2)), set(1)); + expected.put(new IntervalWindow(new Instant(-3), new Instant(7)), set(1, 2, 5)); + expected.put(new IntervalWindow(new Instant(2), new Instant(12)), set(2, 5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(7), new Instant(17)), set(9, 10, 11)); + assertEquals( + expected, + runWindowingFn( + SlidingWindows.of(new Duration(10)).every(new Duration(5)).withOffset(new Duration(2)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + } + + @Test + public void testTimeUnit() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5000), new Instant(5000)), set(1, 2, 1000)); + expected.put(new IntervalWindow(new Instant(0), new Instant(10000)), + set(1, 2, 1000, 5000, 5001)); + expected.put(new IntervalWindow(new Instant(5000), new Instant(15000)), set(5000, 5001, 10000)); + expected.put(new IntervalWindow(new Instant(10000), new Instant(20000)), set(10000)); + assertEquals( + expected, + runWindowingFn( + SlidingWindows.of(Duration.standardSeconds(10)).every(Duration.standardSeconds(5)), + Arrays.asList(1L, 2L, 1000L, 5000L, 5001L, 10000L))); + } + + @Test + public void testEquality() { + assertTrue( + SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(10)))); + assertTrue( + SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(10)))); + + assertFalse(SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(20)))); + assertFalse(SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(20)))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingTest.java new file mode 100644 index 000000000000..85c0bf6b8b6d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingTest.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.Arrays; + +/** Unit tests for bucketing. */ +@RunWith(JUnit4.class) +public class WindowingTest implements Serializable { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private static class WindowedCount extends PTransform, PCollection> { + private WindowingFn windowingFn; + public WindowedCount(WindowingFn windowingFn) { + this.windowingFn = (WindowingFn) windowingFn; + } + @Override + public PCollection apply(PCollection in) { + return in + .apply(Window.named("Window").into(windowingFn)) + .apply(Count.perElement()) + .apply(ParDo + .named("FormatCounts") + .of(new DoFn, String>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey() + ":" + c.element().getValue() + + ":" + c.timestamp().getMillis() + ":" + c.windows()); + } + })) + .setCoder(StringUtf8Coder.of()); + } + } + + private String output(String value, int count, int timestamp, int windowStart, int windowEnd) { + return value + ":" + count + ":" + timestamp + + ":[[" + new Instant(windowStart) + ".." + new Instant(windowEnd) + ")]"; + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testPartitioningWindowing() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("b", new Instant(2)), + TimestampedValue.of("b", new Instant(3)), + TimestampedValue.of("c", new Instant(11)), + TimestampedValue.of("d", new Instant(11)))); + + PCollection output = + input + .apply(new WindowedCount(FixedWindows.of(new Duration(10)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 1, 9, 0, 10), + output("b", 2, 9, 0, 10), + output("c", 1, 19, 10, 20), + output("d", 1, 19, 10, 20)); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testNonPartitioningWindowing() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(7)), + TimestampedValue.of("b", new Instant(8)))); + + PCollection output = + input + .apply(new WindowedCount( + SlidingWindows.of(new Duration(10)).every(new Duration(5)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 1, 4, -5, 5), + output("a", 2, 9, 0, 10), + output("a", 1, 14, 5, 15), + output("b", 1, 9, 0, 10), + output("b", 1, 14, 5, 15)); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testMergingWindowing() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(5)), + TimestampedValue.of("a", new Instant(20)))); + + PCollection output = + input + .apply(new WindowedCount(Sessions.withGapDuration(new Duration(10)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 2, 14, 1, 15), + output("a", 1, 29, 20, 30)); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testWindowPreservation() { + Pipeline p = TestPipeline.create(); + PCollection input1 = p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("b", new Instant(2)))); + + PCollection input2 = p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(3)), + TimestampedValue.of("b", new Instant(4)))); + + PCollectionList input = PCollectionList.of(input1).and(input2); + + PCollection output = + input + .apply(Flatten.create()) + .apply(new WindowedCount(FixedWindows.of(new Duration(5)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 2, 4, 0, 5), + output("b", 2, 4, 0, 5)); + + p.run(); + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testElementsSortedByTimestamp() { + // The Windowing API does not guarantee that elements will be sorted by + // timestamp, but the implementation currently relies on this, so it + // needs to be tested. + + Pipeline p = TestPipeline.create(); + + PCollection> a = p + .apply(Create.timestamped( + TimestampedValue.of(KV.of("k", "a"), new Instant(1)), + TimestampedValue.of(KV.of("k", "b"), new Instant(4)), + TimestampedValue.of(KV.of("k", "c"), new Instant(3)), + TimestampedValue.of(KV.of("k", "d"), new Instant(5)), + TimestampedValue.of(KV.of("k", "e"), new Instant(2)), + TimestampedValue.of(KV.of("k", "f"), new Instant(-5)), + TimestampedValue.of(KV.of("k", "g"), new Instant(-6)), + TimestampedValue.of(KV.of("k", "h"), new Instant(-255)), + TimestampedValue.of(KV.of("k", "i"), new Instant(-256)), + TimestampedValue.of(KV.of("k", "j"), new Instant(255)))); + + PCollection> b = a + .apply(Window.>into( + FixedWindows.of(new Duration(1000)).withOffset(new Duration(500)))); + + PCollection>> output = b + .apply(GroupByKey.create()); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("k", + (Iterable) Arrays.asList("i", "h", "g", "f", "a", "e", "c", "b", "d", "j"))); + + p.run(); + } + + @Test + public void testEmptyInput() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.timestamped()) + .setCoder(StringUtf8Coder.of()); + + PCollection output = + input + .apply(new WindowedCount(FixedWindows.of(new Duration(10)))); + + DataflowAssert.that(output).containsInAnyOrder(); + + p.run(); + } + + @Test + public void testTextIoInput() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + try (PrintStream writer = new PrintStream(new FileOutputStream(tmpFile))) { + writer.println("a 1"); + writer.println("b 2"); + writer.println("b 3"); + writer.println("c 11"); + writer.println("d 11"); + } + + Pipeline p = TestPipeline.create(); + PCollection output = p.begin() + .apply(TextIO.Read.named("ReadLines").from(filename)) + .apply(ParDo.of(new ExtractWordsWithTimestampsFn())) + .apply(new WindowedCount(FixedWindows.of(Duration.millis(10)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 1, 9, 0, 10), + output("b", 2, 9, 0, 10), + output("c", 1, 19, 10, 20), + output("d", 1, 19, 10, 20)); + + p.run(); + } + + /** A DoFn that tokenizes lines of text into individual words. */ + static class ExtractWordsWithTimestampsFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + String[] words = c.element().split("[^a-zA-Z0-9']+"); + if (words.length == 2) { + c.outputWithTimestamp(words[0], new Instant(Long.parseLong(words[1]))); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AggregatorImplTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AggregatorImplTest.java new file mode 100644 index 000000000000..45cc267d2d74 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AggregatorImplTest.java @@ -0,0 +1,194 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MIN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.CounterTestUtils; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for the {@link Aggregator} API. + */ +@RunWith(JUnit4.class) +public class AggregatorImplTest { + @Rule + public final ExpectedException expectedEx = ExpectedException.none(); + + private static final String AGGREGATOR_NAME = "aggregator_name"; + + private void testAggregator(List items, + SerializableFunction, V> combiner, + Counter expectedCounter) { + CounterSet counters = new CounterSet(); + Aggregator aggregator = new AggregatorImpl, V>( + AGGREGATOR_NAME, combiner, counters.getAddCounterMutator()); + for (V item : items) { + aggregator.addValue(item); + } + + List cloudCounterSet = CounterTestUtils.extractCounterUpdates(counters, false); + Assert.assertEquals(cloudCounterSet.size(), 1); + Assert.assertEquals(cloudCounterSet.get(0), + CounterTestUtils.extractCounterUpdate(expectedCounter, false)); + } + + @Test + public void testSumInteger() throws Exception { + testAggregator(Arrays.asList(2, 4, 1, 3), new Sum.SumIntegerFn(), + Counter.ints(AGGREGATOR_NAME, SUM).resetToValue(10)); + } + + @Test + public void testSumLong() throws Exception { + testAggregator(Arrays.asList(2L, 4L, 1L, 3L), new Sum.SumLongFn(), + Counter.longs(AGGREGATOR_NAME, SUM).resetToValue(10L)); + } + + @Test + public void testSumDouble() throws Exception { + testAggregator(Arrays.asList(2.0, 4.1, 1.0, 3.1), new Sum.SumDoubleFn(), + Counter.doubles(AGGREGATOR_NAME, SUM).resetToValue(10.2)); + } + + @Test + public void testMinInteger() throws Exception { + testAggregator(Arrays.asList(2, 4, 1, 3), new Min.MinIntegerFn(), + Counter.ints(AGGREGATOR_NAME, MIN).resetToValue(1)); + } + + @Test + public void testMinLong() throws Exception { + testAggregator(Arrays.asList(2L, 4L, 1L, 3L), new Min.MinLongFn(), + Counter.longs(AGGREGATOR_NAME, MIN).resetToValue(1L)); + } + + @Test + public void testMinDouble() throws Exception { + testAggregator(Arrays.asList(2.0, 4.1, 1.0, 3.1), new Min.MinDoubleFn(), + Counter.doubles(AGGREGATOR_NAME, MIN).resetToValue(1.0)); + } + + @Test + public void testMaxInteger() throws Exception { + testAggregator(Arrays.asList(2, 4, 1, 3), new Max.MaxIntegerFn(), + Counter.ints(AGGREGATOR_NAME, MAX).resetToValue(4)); + } + + @Test + public void testMaxLong() throws Exception { + testAggregator(Arrays.asList(2L, 4L, 1L, 3L), new Max.MaxLongFn(), + Counter.longs(AGGREGATOR_NAME, MAX).resetToValue(4L)); + } + + @Test + public void testMaxDouble() throws Exception { + testAggregator(Arrays.asList(2.0, 4.1, 1.0, 3.1), new Max.MaxDoubleFn(), + Counter.doubles(AGGREGATOR_NAME, MAX).resetToValue(4.1)); + } + + @Test + public void testCompatibleDuplicateNames() throws Exception { + CounterSet counters = new CounterSet(); + Aggregator aggregator1 = + new AggregatorImpl, Integer>( + AGGREGATOR_NAME, new Sum.SumIntegerFn(), + counters.getAddCounterMutator()); + + Aggregator aggregator2 = + new AggregatorImpl, Integer>( + AGGREGATOR_NAME, new Sum.SumIntegerFn(), + counters.getAddCounterMutator()); + + // The duplicate aggregators should update the same counter. + aggregator1.addValue(3); + aggregator2.addValue(4); + Assert.assertEquals( + new CounterSet(Counter.ints(AGGREGATOR_NAME, SUM).resetToValue(7)), + counters); + } + + @Test + public void testIncompatibleDuplicateNames() throws Exception { + CounterSet counters = new CounterSet(); + new AggregatorImpl, Integer>( + AGGREGATOR_NAME, new Sum.SumIntegerFn(), + counters.getAddCounterMutator()); + + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString( + "aggregator's name collides with an existing aggregator or " + + "system-provided counter of an incompatible type")); + new AggregatorImpl, Long>( + AGGREGATOR_NAME, new Sum.SumLongFn(), + counters.getAddCounterMutator()); + } + + @Test + public void testUnsupportedCombineFn() throws Exception { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString("unsupported combiner")); + new AggregatorImpl<>( + AGGREGATOR_NAME, + new Combine.CombineFn, Integer>() { + @Override + public List createAccumulator() { return null; } + @Override + public void addInput(List accumulator, Integer input) { } + @Override + public List mergeAccumulators(Iterable> accumulators) { + return null; } + @Override + public Integer extractOutput(List accumulator) { return null; } + }, + (new CounterSet()).getAddCounterMutator()); + } + + @Test + public void testUnsupportedSerializableFunction() throws Exception { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString("unsupported combiner")); + new AggregatorImpl, Integer>( + AGGREGATOR_NAME, + new SerializableFunction, Integer>() { + @Override + public Integer apply(Iterable input) { return null; } + }, + (new CounterSet()).getAddCounterMutator()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOffTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOffTest.java new file mode 100644 index 000000000000..0c262e2f1cb0 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOffTest.java @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertEquals; + +import com.google.api.client.util.BackOff; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link AttemptBoundedExponentialBackOff}. */ +@RunWith(JUnit4.class) +public class AttemptBoundedExponentialBackOffTest { + @Rule public ExpectedException exception = ExpectedException.none(); + + @Test + public void testUsingInvalidInitialInterval() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Initial interval must be greater than zero."); + new AttemptBoundedExponentialBackOff(10, 0L); + } + + @Test + public void testUsingInvalidMaximumNumberOfRetries() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Maximum number of attempts must be greater than zero."); + new AttemptBoundedExponentialBackOff(-1, 10L); + } + + @Test + public void testThatFixedNumberOfAttemptsExits() throws Exception { + BackOff backOff = new AttemptBoundedExponentialBackOff(3, 500); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testThatResettingAllowsReuse() throws Exception { + BackOff backOff = new AttemptBoundedExponentialBackOff(3, 500); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/Base64UtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/Base64UtilsTest.java new file mode 100644 index 000000000000..d557284ce080 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/Base64UtilsTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.api.client.util.Base64.encodeBase64URLSafeString; + +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link Base64Utils}. */ +@RunWith(JUnit4.class) +public class Base64UtilsTest { + void testLength(int length) { + byte[] b = new byte[length]; + // Make sure that the estimated length is an upper bound. + assertThat( + Base64Utils.getBase64Length(length), + greaterThanOrEqualTo(encodeBase64URLSafeString(b).length())); + // Make sure that it's a tight upper bound (no more than 4 characters off). + assertThat( + Base64Utils.getBase64Length(length), + lessThan(4 + encodeBase64URLSafeString(b).length())); + } + + @Test + public void getBase64Length() { + for (int i = 0; i < 100; ++i) { + testLength(i); + } + for (int i = 1000; i < 1100; ++i) { + testLength(i); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryUtilTest.java new file mode 100644 index 000000000000..ca75e6f94ca7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryUtilTest.java @@ -0,0 +1,306 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableCell; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.common.base.Function; +import com.google.common.collect.Iterators; + +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for util classes related to BigQuery. + */ +@RunWith(JUnit4.class) +public class BigQueryUtilTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock private Bigquery mockClient; + @Mock private Bigquery.Tables mockTables; + @Mock private Bigquery.Tables.Get mockTablesGet; + @Mock private Bigquery.Tabledata mockTabledata; + @Mock private Bigquery.Tabledata.List mockTabledataList; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockClient); + verifyNoMoreInteractions(mockTables); + verifyNoMoreInteractions(mockTablesGet); + verifyNoMoreInteractions(mockTabledata); + verifyNoMoreInteractions(mockTabledataList); + } + + private void onTableGet(Table table) throws IOException { + when(mockClient.tables()) + .thenReturn(mockTables); + when(mockTables.get(anyString(), anyString(), anyString())) + .thenReturn(mockTablesGet); + when(mockTablesGet.execute()) + .thenReturn(table); + } + + private void verifyTableGet() throws IOException { + verify(mockClient).tables(); + verify(mockTables).get("project", "dataset", "table"); + verify(mockTablesGet).execute(); + } + + private void onTableList(TableDataList result) throws IOException { + when(mockClient.tabledata()) + .thenReturn(mockTabledata); + when(mockTabledata.list(anyString(), anyString(), anyString())) + .thenReturn(mockTabledataList); + when(mockTabledataList.execute()) + .thenReturn(result); + } + + private void verifyTabledataList() throws IOException { + verify(mockClient, atLeastOnce()).tabledata(); + verify(mockTabledata, atLeastOnce()).list("project", "dataset", "table"); + verify(mockTabledataList, atLeastOnce()).execute(); + // Max results may be set when testing for an empty table. + verify(mockTabledataList, atLeast(0)).setMaxResults(anyLong()); + } + + private Table basicTableSchema() { + return new Table() + .setSchema(new TableSchema() + .setFields(Arrays.asList( + new TableFieldSchema() + .setName("name") + .setType("STRING"), + new TableFieldSchema() + .setName("answer") + .setType("INTEGER") + ))); + } + + private TableRow rawRow(Object...args) { + List cells = new LinkedList<>(); + for (Object a : args) { + cells.add(new TableCell().setV(a)); + } + return new TableRow().setF(cells); + } + + private TableDataList rawDataList(TableRow...rows) { + return new TableDataList() + .setRows(Arrays.asList(rows)); + } + + @Test + public void testRead() throws IOException { + onTableGet(basicTableSchema()); + + TableDataList dataList = rawDataList(rawRow("Arthur", 42)); + onTableList(dataList); + + BigQueryTableRowIterator iterator = new BigQueryTableRowIterator( + mockClient, + BigQueryIO.parseTableSpec("project:dataset.table")); + + Assert.assertTrue(iterator.hasNext()); + TableRow row = iterator.next(); + + Assert.assertTrue(row.containsKey("name")); + Assert.assertTrue(row.containsKey("answer")); + Assert.assertEquals("Arthur", row.get("name")); + Assert.assertEquals(42, row.get("answer")); + + Assert.assertFalse(iterator.hasNext()); + + verifyTableGet(); + verifyTabledataList(); + } + + @Test + public void testReadEmpty() throws IOException { + onTableGet(basicTableSchema()); + + // BigQuery may respond with a page token for an empty table, ensure we + // handle it. + TableDataList dataList = new TableDataList() + .setPageToken("FEED==") + .setTotalRows(0L); + onTableList(dataList); + + BigQueryTableRowIterator iterator = new BigQueryTableRowIterator( + mockClient, + BigQueryIO.parseTableSpec("project:dataset.table")); + + Assert.assertFalse(iterator.hasNext()); + + verifyTableGet(); + verifyTabledataList(); + } + + @Test + public void testReadMultiPage() throws IOException { + onTableGet(basicTableSchema()); + + TableDataList page1 = rawDataList(rawRow("Row1", 1)) + .setPageToken("page2"); + TableDataList page2 = rawDataList(rawRow("Row2", 2)) + .setTotalRows(2L); + + when(mockClient.tabledata()) + .thenReturn(mockTabledata); + when(mockTabledata.list(anyString(), anyString(), anyString())) + .thenReturn(mockTabledataList); + when(mockTabledataList.execute()) + .thenReturn(page1) + .thenReturn(page2); + + BigQueryTableRowIterator iterator = new BigQueryTableRowIterator( + mockClient, + BigQueryIO.parseTableSpec("project:dataset.table")); + List names = new LinkedList<>(); + Iterators.addAll(names, + Iterators.transform(iterator, new Function(){ + @Override + public String apply(TableRow input) { + return (String) input.get("name"); + } + })); + + Assert.assertThat(names, Matchers.hasItems("Row1", "Row2")); + + verifyTableGet(); + verifyTabledataList(); + // The second call should have used a page token. + verify(mockTabledataList).setPageToken("page2"); + } + + @Test + public void testReadOpenFailure() throws IOException { + thrown.expect(RuntimeException.class); + + when(mockClient.tables()) + .thenReturn(mockTables); + when(mockTables.get(anyString(), anyString(), anyString())) + .thenReturn(mockTablesGet); + when(mockTablesGet.execute()) + .thenThrow(new IOException("No such table")); + + BigQueryTableRowIterator iterator = new BigQueryTableRowIterator( + mockClient, + BigQueryIO.parseTableSpec("project:dataset.table")); + try { + Assert.assertFalse(iterator.hasNext()); // throws. + } finally { + verifyTableGet(); + } + } + + @Test + public void testWriteAppend() throws IOException { + onTableGet(basicTableSchema()); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + + BigQueryTableInserter inserter = + new BigQueryTableInserter(mockClient, ref); + + inserter.getOrCreateTable(BigQueryIO.Write.WriteDisposition.WRITE_APPEND, + BigQueryIO.Write.CreateDisposition.CREATE_NEVER, null); + + verifyTableGet(); + } + + @Test + public void testWriteEmpty() throws IOException { + onTableGet(basicTableSchema()); + + TableDataList dataList = new TableDataList().setTotalRows(0L); + onTableList(dataList); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + + BigQueryTableInserter inserter = + new BigQueryTableInserter(mockClient, ref); + + inserter.getOrCreateTable(BigQueryIO.Write.WriteDisposition.WRITE_EMPTY, + BigQueryIO.Write.CreateDisposition.CREATE_NEVER, null); + + verifyTableGet(); + verifyTabledataList(); + } + + @Test + public void testWriteEmptyFail() throws IOException { + thrown.expect(IOException.class); + + onTableGet(basicTableSchema()); + + TableDataList dataList = rawDataList(rawRow("Arthur", 42)); + onTableList(dataList); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + + BigQueryTableInserter inserter = + new BigQueryTableInserter(mockClient, ref); + + try { + inserter.getOrCreateTable(BigQueryIO.Write.WriteDisposition.WRITE_EMPTY, + BigQueryIO.Write.CreateDisposition.CREATE_NEVER, null); + } finally { + verifyTableGet(); + verifyTabledataList(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CloudMetricUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CloudMetricUtilsTest.java new file mode 100644 index 000000000000..31bc2f9241ea --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CloudMetricUtilsTest.java @@ -0,0 +1,66 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; + +import com.google.api.services.dataflow.model.MetricStructuredName; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.util.common.Metric; +import com.google.cloud.dataflow.sdk.util.common.Metric.DoubleMetric; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Unit tests for {@link CloudMetricUtils}. */ +@RunWith(JUnit4.class) +public class CloudMetricUtilsTest { + private void addDoubleMetric(String name, double value, String workerId, + List> metrics, + List cloudMetrics) { + metrics.add(new DoubleMetric(name, value)); + MetricStructuredName structuredName = new MetricStructuredName(); + structuredName.setName(name); + Map context = new HashMap<>(); + context.put("workerId", workerId); + structuredName.setContext(context); + cloudMetrics.add(new MetricUpdate() + .setName(structuredName) + .setScalar(CloudObject.forFloat(value))); + } + + @Test + public void testExtractCloudMetrics() { + List> metrics = new ArrayList<>(); + List expected = new ArrayList<>(); + String workerId = "worker-id"; + + addDoubleMetric("m1", 3.14, workerId, metrics, expected); + addDoubleMetric("m2", 2.17, workerId, metrics, expected); + addDoubleMetric("m3", -66.666, workerId, metrics, expected); + + List actual = CloudMetricUtils.extractCloudMetrics(metrics, workerId); + + assertEquals(expected, actual); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CloudSourceUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CloudSourceUtilsTest.java new file mode 100644 index 000000000000..d813a103fabb --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CloudSourceUtilsTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.api.services.dataflow.model.Source; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Map; + +/** + * Tests for {@code CloudSourceUtils}. + */ +@RunWith(JUnit4.class) +public class CloudSourceUtilsTest { + @Test + public void testFlattenBaseSpecs() throws Exception { + // G = grandparent, P = parent, C = child. + CloudObject grandparent = CloudObject.forClassName("text"); + addString(grandparent, "G", "g_g"); + addString(grandparent, "GP", "gp_g"); + addString(grandparent, "GC", "gc_g"); + addString(grandparent, "GPC", "gpc_g"); + + CloudObject parent = CloudObject.forClassName("text"); + addString(parent, "P", "p_p"); + addString(parent, "PC", "pc_p"); + addString(parent, "GP", "gp_p"); + addString(parent, "GPC", "gpc_p"); + + CloudObject child = CloudObject.forClassName("text"); + addString(child, "C", "c_c"); + addString(child, "PC", "pc_c"); + addString(child, "GC", "gc_c"); + addString(child, "GPC", "gpc_c"); + + Source source = new Source(); + source.setBaseSpecs(new ArrayList>()); + source.getBaseSpecs().add(grandparent); + source.getBaseSpecs().add(parent); + source.setSpec(child); + source.setCodec(makeCloudEncoding(StringUtf8Coder.class.getName())); + + Source flat = CloudSourceUtils.flattenBaseSpecs(source); + assertNull(flat.getBaseSpecs()); + assertEquals( + StringUtf8Coder.class.getName(), + getString(flat.getCodec(), PropertyNames.OBJECT_TYPE_NAME)); + + CloudObject flatSpec = CloudObject.fromSpec(flat.getSpec()); + assertEquals("g_g", getString(flatSpec, "G")); + assertEquals("p_p", getString(flatSpec, "P")); + assertEquals("c_c", getString(flatSpec, "C")); + assertEquals("gp_p", getString(flatSpec, "GP")); + assertEquals("gc_c", getString(flatSpec, "GC")); + assertEquals("pc_c", getString(flatSpec, "PC")); + assertEquals("gpc_c", getString(flatSpec, "GPC")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CoderUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CoderUtilsTest.java new file mode 100644 index 000000000000..92f9e7481558 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CoderUtilsTest.java @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * Tests for CoderUtils. + */ +@RunWith(JUnit4.class) +public class CoderUtilsTest { + static class TestCoder extends AtomicCoder { + public static TestCoder of() { return new TestCoder(); } + + @Override + public void encode(Integer value, OutputStream outStream, Context context) { + throw new RuntimeException("not expecting to be called"); + } + + @Override + public Integer decode(InputStream inStream, Context context) { + throw new RuntimeException("not expecting to be called"); + } + + @Override + public boolean isDeterministic() { + return false; + } + } + + @Test + public void testCreateAtomicCoders() throws Exception { + Assert.assertEquals( + BigEndianIntegerCoder.of(), + Serializer.deserialize(makeCloudEncoding("BigEndianIntegerCoder"), Coder.class)); + Assert.assertEquals( + StringUtf8Coder.of(), + Serializer.deserialize( + makeCloudEncoding(StringUtf8Coder.class.getName()), Coder.class)); + Assert.assertEquals( + VoidCoder.of(), + Serializer.deserialize(makeCloudEncoding("VoidCoder"), Coder.class)); + Assert.assertEquals( + TestCoder.of(), + Serializer.deserialize(makeCloudEncoding(TestCoder.class.getName()), Coder.class)); + } + + @Test + public void testCreateCompositeCoders() throws Exception { + Assert.assertEquals( + IterableCoder.of(StringUtf8Coder.of()), + Serializer.deserialize( + makeCloudEncoding("IterableCoder", + makeCloudEncoding("StringUtf8Coder")), Coder.class)); + Assert.assertEquals( + KvCoder.of(BigEndianIntegerCoder.of(), VoidCoder.of()), + Serializer.deserialize( + makeCloudEncoding( + "KvCoder", + makeCloudEncoding(BigEndianIntegerCoder.class.getName()), + makeCloudEncoding("VoidCoder")), Coder.class)); + Assert.assertEquals( + IterableCoder.of( + KvCoder.of(IterableCoder.of(BigEndianIntegerCoder.of()), + KvCoder.of(VoidCoder.of(), + TestCoder.of()))), + Serializer.deserialize( + makeCloudEncoding( + IterableCoder.class.getName(), + makeCloudEncoding( + KvCoder.class.getName(), + makeCloudEncoding( + "IterableCoder", + makeCloudEncoding("BigEndianIntegerCoder")), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("VoidCoder"), + makeCloudEncoding(TestCoder.class.getName())))), Coder.class)); + } + + @Test + public void testCreateUntypedCoders() throws Exception { + Assert.assertEquals( + IterableCoder.of(StringUtf8Coder.of()), + Serializer.deserialize( + makeCloudEncoding( + "kind:stream", + makeCloudEncoding("StringUtf8Coder")), Coder.class)); + Assert.assertEquals( + KvCoder.of(BigEndianIntegerCoder.of(), VoidCoder.of()), + Serializer.deserialize( + makeCloudEncoding( + "kind:pair", + makeCloudEncoding(BigEndianIntegerCoder.class.getName()), + makeCloudEncoding("VoidCoder")), Coder.class)); + Assert.assertEquals( + IterableCoder.of( + KvCoder.of(IterableCoder.of(BigEndianIntegerCoder.of()), + KvCoder.of(VoidCoder.of(), + TestCoder.of()))), + Serializer.deserialize( + makeCloudEncoding( + "kind:stream", + makeCloudEncoding( + "kind:pair", + makeCloudEncoding( + "kind:stream", + makeCloudEncoding("BigEndianIntegerCoder")), + makeCloudEncoding( + "kind:pair", + makeCloudEncoding("VoidCoder"), + makeCloudEncoding(TestCoder.class.getName())))), Coder.class)); + } + + @Test + public void testCreateUnknownCoder() throws Exception { + try { + Serializer.deserialize(makeCloudEncoding("UnknownCoder"), Coder.class); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + CoreMatchers.containsString( + "Unable to convert coder ID UnknownCoder to class")); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsUtilTest.java new file mode 100644 index 000000000000..cae705cea5a4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsUtilTest.java @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.util.Throwables; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +/** Test case for {@link GcsUtil}. */ +@RunWith(JUnit4.class) +public class GcsUtilTest { + @Test + public void testGlobTranslation() { + assertEquals("foo", GcsUtil.globToRegexp("foo")); + assertEquals("fo[^/]*o", GcsUtil.globToRegexp("fo*o")); + assertEquals("f[^/]*o\\.[^/]", GcsUtil.globToRegexp("f*o.?")); + assertEquals("foo-[0-9][^/]*", GcsUtil.globToRegexp("foo-[0-9]*")); + } + + @Test + public void testCreationWithDefaultOptions() { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + pipelineOptions.setGcpCredential(Mockito.mock(Credential.class)); + assertNotNull(pipelineOptions.getGcpCredential()); + } + + @Test + public void testCreationWithExecutorServiceProvided() { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + pipelineOptions.setGcpCredential(Mockito.mock(Credential.class)); + pipelineOptions.setExecutorService(Executors.newCachedThreadPool()); + assertSame(pipelineOptions.getExecutorService(), pipelineOptions.getGcsUtil().executorService); + } + + @Test + public void testCreationWithGcsUtilProvided() { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + GcsUtil gcsUtil = Mockito.mock(GcsUtil.class); + pipelineOptions.setGcsUtil(gcsUtil); + assertSame(gcsUtil, pipelineOptions.getGcsUtil()); + } + + @Test + public void testMultipleThreadsCanCompleteOutOfOrderWithDefaultThreadPool() throws Exception { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + ExecutorService executorService = pipelineOptions.getExecutorService(); + + int numThreads = 1000; + final CountDownLatch[] countDownLatches = new CountDownLatch[numThreads]; + for (int i = 0; i < numThreads; i++) { + final int currentLatch = i; + countDownLatches[i] = new CountDownLatch(1); + executorService.execute(new Runnable() { + @Override + public void run() { + // Wait for latch N and then release latch N - 1 + try { + countDownLatches[currentLatch].await(); + if (currentLatch > 0) { + countDownLatches[currentLatch - 1].countDown(); + } + } catch (InterruptedException e) { + throw Throwables.propagate(e); + } + } + }); + } + + // Release the last latch starting the chain reaction. + countDownLatches[countDownLatches.length - 1].countDown(); + executorService.shutdown(); + assertTrue("Expected tasks to complete", + executorService.awaitTermination(10, TimeUnit.SECONDS)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFnTest.java new file mode 100644 index 000000000000..d482d2c4d345 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFnTest.java @@ -0,0 +1,231 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link GroupAlsoByWindowsDoFn}. */ +@RunWith(JUnit4.class) +public class GroupAlsoByWindowsDoFnTest { + ExecutionContext execContext; + CounterSet counters; + TupleTag>> outputTag; + + @Before public void setUp() { + execContext = new DirectModeExecutionContext(); + counters = new CounterSet(); + outputTag = new TupleTag<>(); + } + + @Test public void testEmpty() throws Exception { + DoFnRunner>>, + KV>, List> runner = + makeRunner(FixedWindows.of(Duration.millis(10))); + + runner.startBundle(); + + runner.finishBundle(); + + List>> result = runner.getReceiver(outputTag); + + assertEquals(0, result.size()); + } + + @Test public void testFixedWindows() throws Exception { + DoFnRunner>>, + KV>, List> runner = + makeRunner(FixedWindows.of(Duration.millis(10))); + + runner.startBundle(); + + runner.processElement(WindowedValue.valueInEmptyWindows( + KV.of("k", (Iterable>) Arrays.asList( + WindowedValue.of( + "v1", + new Instant(1), + Arrays.asList(window(0, 10))), + WindowedValue.of( + "v2", + new Instant(2), + Arrays.asList(window(0, 10))), + WindowedValue.of( + "v3", + new Instant(13), + Arrays.asList(window(10, 20))))))); + + runner.finishBundle(); + + List>>> result = runner.getReceiver(outputTag); + + assertEquals(2, result.size()); + + WindowedValue>> item0 = result.get(0); + assertEquals("k", item0.getValue().getKey()); + assertThat(item0.getValue().getValue(), Matchers.contains("v1", "v2")); + assertEquals(new Instant(9), item0.getTimestamp()); + assertThat(item0.getWindows(), + Matchers.contains(window(0, 10))); + + WindowedValue>> item1 = result.get(1); + assertEquals("k", item1.getValue().getKey()); + assertThat(item1.getValue().getValue(), Matchers.contains("v3")); + assertEquals(new Instant(19), item1.getTimestamp()); + assertThat(item1.getWindows(), + Matchers.contains(window(10, 20))); + } + + @Test public void testSlidingWindows() throws Exception { + DoFnRunner>>, + KV>, List> runner = + makeRunner(SlidingWindows.of(Duration.millis(20)).every(Duration.millis(10))); + + runner.startBundle(); + + runner.processElement(WindowedValue.valueInEmptyWindows( + KV.of("k", (Iterable>) Arrays.asList( + WindowedValue.of( + "v1", + new Instant(5), + Arrays.asList(window(-10, 10), window(0, 20))), + WindowedValue.of( + "v2", + new Instant(15), + Arrays.asList(window(0, 20), window(10, 30))))))); + + runner.finishBundle(); + + List>>> result = runner.getReceiver(outputTag); + + assertEquals(3, result.size()); + + WindowedValue>> item0 = result.get(0); + assertEquals("k", item0.getValue().getKey()); + assertThat(item0.getValue().getValue(), Matchers.contains("v1")); + assertEquals(new Instant(9), item0.getTimestamp()); + assertThat(item0.getWindows(), + Matchers.contains(window(-10, 10))); + + WindowedValue>> item1 = result.get(1); + assertEquals("k", item1.getValue().getKey()); + assertThat(item1.getValue().getValue(), Matchers.contains("v1", "v2")); + assertEquals(new Instant(19), item1.getTimestamp()); + assertThat(item1.getWindows(), + Matchers.contains(window(0, 20))); + + WindowedValue>> item2 = result.get(2); + assertEquals("k", item2.getValue().getKey()); + assertThat(item2.getValue().getValue(), Matchers.contains("v2")); + assertEquals(new Instant(29), item2.getTimestamp()); + assertThat(item2.getWindows(), + Matchers.contains(window(10, 30))); + } + + @Test public void testSessions() throws Exception { + DoFnRunner>>, + KV>, List> runner = + makeRunner(Sessions.withGapDuration(Duration.millis(10))); + + runner.startBundle(); + + runner.processElement(WindowedValue.valueInEmptyWindows( + KV.of("k", (Iterable>) Arrays.asList( + WindowedValue.of( + "v1", + new Instant(0), + Arrays.asList(window(0, 10))), + WindowedValue.of( + "v2", + new Instant(5), + Arrays.asList(window(5, 15))), + WindowedValue.of( + "v3", + new Instant(15), + Arrays.asList(window(15, 25))))))); + + runner.finishBundle(); + + List>>> result = runner.getReceiver(outputTag); + + assertEquals(2, result.size()); + + WindowedValue>> item0 = result.get(0); + assertEquals("k", item0.getValue().getKey()); + assertThat(item0.getValue().getValue(), Matchers.contains("v1", "v2")); + assertEquals(new Instant(14), item0.getTimestamp()); + assertThat(item0.getWindows(), + Matchers.contains(window(0, 15))); + + WindowedValue>> item1 = result.get(1); + assertEquals("k", item1.getValue().getKey()); + assertThat(item1.getValue().getValue(), Matchers.contains("v3")); + assertEquals(new Instant(24), item1.getTimestamp()); + assertThat(item1.getWindows(), + Matchers.contains(window(15, 25))); + } + + + private DoFnRunner>>, + KV>, List> makeRunner( + WindowingFn windowingFn) { + + GroupAlsoByWindowsDoFn fn = + new GroupAlsoByWindowsDoFn( + windowingFn, StringUtf8Coder.of()); + + DoFnRunner>>, + KV>, List> runner = + DoFnRunner.createWithListOutputs( + PipelineOptionsFactory.create(), + fn, + PTuple.empty(), + outputTag, + new ArrayList>(), + execContext.createStepContext("merge"), + counters.getAddCounterMutator()); + + return runner; + } + + private BoundedWindow window(long start, long end) { + return new IntervalWindow(new Instant(start), new Instant(end)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOChannelUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOChannelUtilsTest.java new file mode 100644 index 000000000000..fe82972044d0 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOChannelUtilsTest.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.nio.channels.WritableByteChannel; + +/** + * Tests for IOChannelUtils. + */ +@RunWith(JUnit4.class) +public class IOChannelUtilsTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testShardFormatExpansion() { + Assert.assertEquals("output-001-of-123.txt", + IOChannelUtils.constructName("output", "-SSS-of-NNN", + ".txt", + 1, 123)); + + Assert.assertEquals("out.txt/part-00042", + IOChannelUtils.constructName("out.txt", "/part-SSSSS", "", + 42, 100)); + + Assert.assertEquals("out.txt", + IOChannelUtils.constructName("ou", "t.t", "xt", 1, 1)); + + Assert.assertEquals("out0102shard.txt", + IOChannelUtils.constructName("out", "SSNNshard", ".txt", 1, 2)); + + Assert.assertEquals("out-2/1.part-1-of-2.txt", + IOChannelUtils.constructName("out", "-N/S.part-S-of-N", + ".txt", 1, 2)); + } + + @Test(expected = IllegalArgumentException.class) + public void testShardNameCollision() throws Exception { + File outFolder = tmpFolder.newFolder(); + String filename = outFolder.toPath().resolve("output").toString(); + + WritableByteChannel output = IOChannelUtils + .create(filename, "", "", 2, "text"); + Assert.fail("IOChannelUtils.create expected to fail due " + + "to filename collision"); + } + + @Test + public void testLargeShardCount() { + Assert.assertEquals("out-100-of-5000.txt", + IOChannelUtils.constructName("out", "-SS-of-NN", ".txt", + 100, 5000)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOFactoryTest.java new file mode 100644 index 000000000000..fbf2f70b2235 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOFactoryTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.worker.TextSource; +import com.google.cloud.dataflow.sdk.util.common.worker.Source; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.util.Collection; +import java.util.Set; +import java.util.TreeSet; + +/** + * Tests for IOFactory. + */ +@RunWith(JUnit4.class) +public class IOFactoryTest { + + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testLocalFileIO() throws Exception { + // Create some files to match against. + File foo1 = tmpFolder.newFile("foo1"); + foo1.createNewFile(); + File foo2 = tmpFolder.newFile("foo2"); + foo2.createNewFile(); + tmpFolder.newFile("barf").createNewFile(); + + FileIOChannelFactory factory = new FileIOChannelFactory(); + Collection paths = factory.match(tmpFolder.getRoot() + "/f*"); + + Assert.assertEquals(2, paths.size()); + Assert.assertTrue(paths.contains(foo1.getCanonicalPath())); + Assert.assertTrue(paths.contains(foo2.getCanonicalPath())); + } + + @Test + public void testMultiFileRead() throws Exception { + File file1 = tmpFolder.newFile("file1"); + FileOutputStream output = new FileOutputStream(file1); + output.write("1\n2".getBytes()); + output.close(); + + File file2 = tmpFolder.newFile("file2"); + output = new FileOutputStream(file2); + output.write("3\n4\n".getBytes()); + output.close(); + + File file3 = tmpFolder.newFile("file3"); + output = new FileOutputStream(file3); + output.write("5".getBytes()); + output.close(); + + + TextSource source = new TextSource<>( + tmpFolder.getRoot() + "/file*", + true /* strip newlines */, + null, null, StringUtf8Coder.of()); + + Set records = new TreeSet<>(); + try (Source.SourceIterator iterator = source.iterator()) { + while (iterator.hasNext()) { + records.add(iterator.next()); + } + } + + Assert.assertEquals(records.toString(), 5, records.size()); + Assert.assertTrue(records.contains("1")); + Assert.assertTrue(records.contains("2")); + Assert.assertTrue(records.contains("3")); + Assert.assertTrue(records.contains("4")); + Assert.assertTrue(records.contains("5")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/InstanceBuilderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/InstanceBuilderTest.java new file mode 100644 index 000000000000..18777b2aa394 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/InstanceBuilderTest.java @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests of InstanceBuilder. + */ +@RunWith(JUnit4.class) +public class InstanceBuilderTest { + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @SuppressWarnings("unused") + private static TupleTag createTag(String id) { + return new TupleTag(id); + } + + @Test + public void testFullNameLookup() throws Exception { + TupleTag tag = InstanceBuilder.ofType(TupleTag.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("createTag") + .withArg(String.class, "hello world!") + .build(); + + Assert.assertEquals("hello world!", tag.getId()); + } + + @Test + public void testConstructor() throws Exception { + TupleTag tag = InstanceBuilder.ofType(TupleTag.class) + .withArg(String.class, "hello world!") + .build(); + + Assert.assertEquals("hello world!", tag.getId()); + } + + @Test + public void testBadMethod() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("Unable to find factory method")); + + InstanceBuilder.ofType(String.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("nonexistantFactoryMethod") + .withArg(String.class, "hello") + .withArg(String.class, " world!") + .build(); + } + + @Test + public void testBadArgs() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("Unable to find factory method")); + + InstanceBuilder.ofType(TupleTag.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("createTag") + .withArg(String.class, "hello") + .withArg(Integer.class, 42) + .build(); + } + + @Test + public void testBadReturnType() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("must be assignable to String")); + + InstanceBuilder.ofType(String.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("createTag") + .withArg(String.class, "hello") + .build(); + } + + @Test + public void testWrongType() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("must be assignable to TupleTag")); + + InstanceBuilder.ofType(TupleTag.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .build(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MonitoringUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MonitoringUtilTest.java new file mode 100644 index 000000000000..8ec3012da448 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MonitoringUtilTest.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.ListJobMessagesResponse; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for MonitoringUtil. + */ +@RunWith(JUnit4.class) +public class MonitoringUtilTest { + private static final String PROJECT_ID = "someProject"; + private static final String JOB_ID = "1234"; + + @Test + public void testGetJobMessages() throws IOException { + Dataflow.V1b3.Projects.Jobs.Messages mockMessages = + mock(Dataflow.V1b3.Projects.Jobs.Messages.class); + + // Two requests are needed to get all the messages. + Dataflow.V1b3.Projects.Jobs.Messages.List firstRequest = + mock(Dataflow.V1b3.Projects.Jobs.Messages.List.class); + Dataflow.V1b3.Projects.Jobs.Messages.List secondRequest = + mock(Dataflow.V1b3.Projects.Jobs.Messages.List.class); + + when(mockMessages.list(PROJECT_ID, JOB_ID)) + .thenReturn(firstRequest) + .thenReturn(secondRequest); + + ListJobMessagesResponse firstResponse = new ListJobMessagesResponse(); + firstResponse.setJobMessages(new ArrayList()); + for (int i = 0; i < 100; ++i) { + JobMessage message = new JobMessage(); + message.setId("message_" + i); + message.setTime(TimeUtil.toCloudTime(new Instant(i))); + firstResponse.getJobMessages().add(message); + } + String pageToken = "page_token"; + firstResponse.setNextPageToken(pageToken); + + ListJobMessagesResponse secondResponse = new ListJobMessagesResponse(); + secondResponse.setJobMessages(new ArrayList()); + for (int i = 100; i < 150; ++i) { + JobMessage message = new JobMessage(); + message.setId("message_" + i); + message.setTime(TimeUtil.toCloudTime(new Instant(i))); + secondResponse.getJobMessages().add(message); + } + + when(firstRequest.execute()).thenReturn(firstResponse); + when(secondRequest.execute()).thenReturn(secondResponse); + + MonitoringUtil util = new MonitoringUtil(PROJECT_ID, mockMessages); + + List messages = util.getJobMessages(JOB_ID, -1); + + verify(secondRequest).setPageToken(pageToken); + + assertEquals(150, messages.size()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PTupleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PTupleTest.java new file mode 100644 index 000000000000..3692411a4a75 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PTupleTest.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link PTuple}. */ +@RunWith(JUnit4.class) +public final class PTupleTest { + @Test + public void accessingNullVoidValuesShouldNotCauseExceptions() { + TupleTag tag = new TupleTag() {}; + PTuple tuple = PTuple.of(tag, null); + assertTrue(tuple.has(tag)); + assertThat(tuple.get(tag), is(nullValue())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PackageUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PackageUtilTest.java new file mode 100644 index 000000000000..7d923c2fcdb7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PackageUtilTest.java @@ -0,0 +1,342 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.cloud.dataflow.sdk.testing.FastNanoClockAndSleeper; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.io.Files; +import com.google.common.io.LineReader; + +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.File; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.Pipe; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; + +/** Tests for PackageUtil. */ +@RunWith(JUnit4.class) +public class PackageUtilTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule + public FastNanoClockAndSleeper fastNanoClockAndSleeper = new FastNanoClockAndSleeper(); + + @Mock + GcsUtil mockGcsUtil; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testPackageNamingWithFileHavingExtension() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + + DataflowPackage target = PackageUtil.createPackage(tmpFile.getAbsolutePath(), gcsStaging, null); + + assertEquals("file-SAzzqSB2zmoIgNHC9A2G0A.txt", target.getName()); + assertEquals("storage.googleapis.com/somebucket/base/path/file-SAzzqSB2zmoIgNHC9A2G0A.txt", + target.getLocation()); + } + + @Test + public void testPackageNamingWithFileMissingExtension() throws Exception { + File tmpFile = tmpFolder.newFile("file"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + + DataflowPackage target = PackageUtil.createPackage(tmpFile.getAbsolutePath(), gcsStaging, null); + + assertEquals("file-SAzzqSB2zmoIgNHC9A2G0A", target.getName()); + assertEquals("storage.googleapis.com/somebucket/base/path/file-SAzzqSB2zmoIgNHC9A2G0A", + target.getLocation()); + } + + @Test + public void testPackageNamingWithDirectory() throws Exception { + File tmpDirectory = tmpFolder.newFolder("folder"); + File tmpFile = tmpFolder.newFile("folder/file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + + DataflowPackage target = + PackageUtil.createPackage(tmpDirectory.getAbsolutePath(), gcsStaging, null); + + assertEquals("folder-9MHI5fxducQ06t3IG9MC-g.zip", target.getName()); + assertEquals("storage.googleapis.com/somebucket/base/path/folder-9MHI5fxducQ06t3IG9MC-g.zip", + target.getLocation()); + } + + @Test + public void testPackageNamingWithFilesHavingSameContentsButDifferentNames() throws Exception { + tmpFolder.newFolder("folder1"); + File tmpDirectory1 = tmpFolder.newFolder("folder1/folderA"); + File tmpFile1 = tmpFolder.newFile("folder1/folderA/uniqueName1"); + Files.write("This is a test!", tmpFile1, StandardCharsets.UTF_8); + + tmpFolder.newFolder("folder2"); + File tmpDirectory2 = tmpFolder.newFolder("folder2/folderA"); + File tmpFile2 = tmpFolder.newFile("folder2/folderA/uniqueName2"); + Files.write("This is a test!", tmpFile2, StandardCharsets.UTF_8); + + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + + DataflowPackage target1 = + PackageUtil.createPackage(tmpDirectory1.getAbsolutePath(), gcsStaging, null); + DataflowPackage target2 = + PackageUtil.createPackage(tmpDirectory2.getAbsolutePath(), gcsStaging, null); + + assertFalse(target1.getName().equals(target2.getName())); + assertFalse(target1.getLocation().equals(target2.getLocation())); + } + + @Test + public void testPackageNamingWithDirectoriesHavingSameContentsButDifferentNames() + throws Exception { + tmpFolder.newFolder("folder1"); + File tmpDirectory1 = tmpFolder.newFolder("folder1/folderA"); + tmpFolder.newFolder("folder1/folderA/uniqueName1"); + + tmpFolder.newFolder("folder2"); + File tmpDirectory2 = tmpFolder.newFolder("folder2/folderA"); + tmpFolder.newFolder("folder2/folderA/uniqueName2"); + + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + + DataflowPackage target1 = + PackageUtil.createPackage(tmpDirectory1.getAbsolutePath(), gcsStaging, null); + DataflowPackage target2 = + PackageUtil.createPackage(tmpDirectory2.getAbsolutePath(), gcsStaging, null); + + assertFalse(target1.getName().equals(target2.getName())); + assertFalse(target1.getLocation().equals(target2.getLocation())); + } + + @Test + public void testPackageUploadWithFileSucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpFile = tmpFolder.newFile("file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(-1L); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + List targets = PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpFile.getAbsolutePath()), gcsStaging); + DataflowPackage target = Iterables.getOnlyElement(targets); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + assertEquals("file-SAzzqSB2zmoIgNHC9A2G0A.txt", target.getName()); + assertEquals("storage.googleapis.com/somebucket/base/path/file-SAzzqSB2zmoIgNHC9A2G0A.txt", + target.getLocation()); + assertEquals("This is a test!", + new LineReader(Channels.newReader(pipe.source(), "UTF-8")).readLine()); + } + + @Test + public void testPackageUploadWithDirectorySucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpDirectory = tmpFolder.newFolder("folder"); + tmpFolder.newFolder("folder/empty_directory"); + tmpFolder.newFolder("folder/directory"); + File tmpFile1 = tmpFolder.newFile("folder/file.txt"); + File tmpFile2 = tmpFolder.newFile("folder/directory/file.txt"); + Files.write("This is a test!", tmpFile1, StandardCharsets.UTF_8); + Files.write("This is also a test!", tmpFile2, StandardCharsets.UTF_8); + + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(-1L); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpDirectory.getAbsolutePath()), gcsStaging); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + ZipInputStream inputStream = new ZipInputStream(Channels.newInputStream(pipe.source())); + List zipEntryNames = new ArrayList<>(); + for (ZipEntry entry = inputStream.getNextEntry(); entry != null; + entry = inputStream.getNextEntry()) { + zipEntryNames.add(entry.getName()); + } + assertTrue(CoreMatchers.hasItems("directory/file.txt", "empty_directory/", "file.txt").matches( + zipEntryNames)); + } + + @Test + public void testPackageUploadWithEmptyDirectorySucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpDirectory = tmpFolder.newFolder("folder"); + + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(-1L); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + List targets = PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpDirectory.getAbsolutePath()), gcsStaging); + DataflowPackage target = Iterables.getOnlyElement(targets); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + assertEquals("folder-wstW9MW_ZW-soJhufroDCA.zip", target.getName()); + assertEquals("storage.googleapis.com/somebucket/base/path/folder-wstW9MW_ZW-soJhufroDCA.zip", + target.getLocation()); + assertNull(new ZipInputStream(Channels.newInputStream(pipe.source())).getNextEntry()); + } + + @Test(expected = RuntimeException.class) + public void testPackageUploadFailsWhenIOExceptionThrown() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(-1L); + when(mockGcsUtil.create(any(GcsPath.class), anyString())) + .thenThrow(new IOException("Upload error")); + + try { + PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpFile.getAbsolutePath()), gcsStaging, fastNanoClockAndSleeper); + } finally { + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil, times(5)).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + } + + @Test + public void testPackageUploadEventuallySucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpFile = tmpFolder.newFile("file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(-1L); + when(mockGcsUtil.create(any(GcsPath.class), anyString())) + .thenThrow(new IOException("410 Gone")) // First attempt fails + .thenReturn(pipe.sink()); // second attempt succeeds + + try { + PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpFile.getAbsolutePath()), + gcsStaging, + fastNanoClockAndSleeper); + } finally { + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil, times(2)).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + } + + @Test + public void testPackageUploadIsSkippedWhenFileAlreadyExists() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(tmpFile.length()); + + PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpFile.getAbsolutePath()), gcsStaging); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verifyNoMoreInteractions(mockGcsUtil); + } + + @Test + public void testPackageUploadIsNotSkippedWhenSizesAreDifferent() throws Exception { + Pipe pipe = Pipe.open(); + File tmpDirectory = tmpFolder.newFolder("folder"); + tmpFolder.newFolder("folder/empty_directory"); + tmpFolder.newFolder("folder/directory"); + File tmpFile1 = tmpFolder.newFile("folder/file.txt"); + File tmpFile2 = tmpFolder.newFile("folder/directory/file.txt"); + Files.write("This is a test!", tmpFile1, StandardCharsets.UTF_8); + Files.write("This is also a test!", tmpFile2, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(Long.MAX_VALUE); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(tmpDirectory.getAbsolutePath()), gcsStaging); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + + @Test + public void testPackageUploadWithExplicitPackageName() throws Exception { + Pipe pipe = Pipe.open(); + File tmpFile = tmpFolder.newFile("file.txt"); + Files.write("This is a test!", tmpFile, StandardCharsets.UTF_8); + GcsPath gcsStaging = GcsPath.fromComponents("somebucket", "base/path"); + final String overriddenName = "alias.txt"; + + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(-1L); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + List targets = PackageUtil.stageClasspathElementsToGcs(mockGcsUtil, + ImmutableList.of(overriddenName + "=" + tmpFile.getAbsolutePath()), gcsStaging); + DataflowPackage target = Iterables.getOnlyElement(targets); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + assertEquals(overriddenName, target.getName()); + assertEquals("storage.googleapis.com/somebucket/base/path/file-SAzzqSB2zmoIgNHC9A2G0A.txt", + target.getLocation()); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializerTest.java new file mode 100644 index 000000000000..45924560630b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializerTest.java @@ -0,0 +1,234 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpResponseException; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.LowLevelHttpRequest; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; +import com.google.api.services.storage.Storage; + +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.security.PrivateKey; + +/** + * Tests for RetryHttpRequestInitializer. + */ +@RunWith(JUnit4.class) +public class RetryHttpRequestInitializerTest { + + @Mock private Credential mockCredential; + @Mock private PrivateKey mockPrivateKey; + @Mock private LowLevelHttpRequest mockLowLevelRequest; + @Mock private LowLevelHttpResponse mockLowLevelResponse; + + private final JsonFactory jsonFactory = JacksonFactory.getDefaultInstance(); + private Storage storage; + + // Used to test retrying a request more than the default 10 times. + static class MockNanoClock implements NanoClock { + private int timesMs[] = {500, 750, 1125, 1688, 2531, 3797, 5695, 8543, + 12814, 19222, 28833, 43249, 64873, 97310, 145965, 218945, 328420}; + private int i = 0; + + @Override + public long nanoTime() { + return timesMs[i++ / 2] * 1000000; + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + HttpTransport lowLevelTransport = new HttpTransport() { + @Override + protected LowLevelHttpRequest buildRequest(String method, String url) + throws IOException { + return mockLowLevelRequest; + } + }; + + // Retry initializer will pass through to credential, since we can have + // only a single HttpRequestInitializer, and we use multiple Credential + // types in the SDK, not all of which allow for retry configuration. + RetryHttpRequestInitializer initializer = new RetryHttpRequestInitializer( + mockCredential, new MockNanoClock(), new Sleeper() { + @Override + public void sleep(long millis) throws InterruptedException {} + }); + storage = new Storage.Builder(lowLevelTransport, jsonFactory, initializer) + .setApplicationName("test").build(); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockPrivateKey); + verifyNoMoreInteractions(mockLowLevelRequest); + verifyNoMoreInteractions(mockCredential); + } + + @Test + public void testBasicOperation() throws IOException { + when(mockLowLevelRequest.execute()) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(200); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest).execute(); + verify(mockLowLevelResponse).getStatusCode(); + } + + /** + * Tests that a non-retriable error is not retried. + */ + @Test + public void testErrorCodeForbidden() throws IOException { + when(mockLowLevelRequest.execute()) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(403) // Non-retryable error. + .thenReturn(200); // Shouldn't happen. + + try { + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + } catch (HttpResponseException e) { + Assert.assertThat(e.getMessage(), Matchers.containsString("403")); + } + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest).execute(); + verify(mockLowLevelResponse).getStatusCode(); + } + + /** + * Tests that a retriable error is retried. + */ + @Test + public void testRetryableError() throws IOException { + when(mockLowLevelRequest.execute()) + .thenReturn(mockLowLevelResponse) + .thenReturn(mockLowLevelResponse) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(503) // Retryable + .thenReturn(429) // We also retry on 429 Too Many Requests. + .thenReturn(200); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest, times(3)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(3)).execute(); + verify(mockLowLevelResponse, times(3)).getStatusCode(); + } + + /** + * Tests that an IOException is retried. + */ + @Test + public void testThrowIOException() throws IOException { + when(mockLowLevelRequest.execute()) + .thenThrow(new IOException("Fake Error")) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(200); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest, times(2)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(2)).execute(); + verify(mockLowLevelResponse).getStatusCode(); + } + + /** + * Tests that a retryable error is retried enough times. + */ + @Test + public void testRetryableErrorRetryEnoughTimes() throws IOException { + when(mockLowLevelRequest.execute()).thenReturn(mockLowLevelResponse); + final int retries = 10; + when(mockLowLevelResponse.getStatusCode()).thenAnswer(new Answer(){ + int n = 0; + @Override + public Integer answer(InvocationOnMock invocation) { + return (n++ < retries - 1) ? 503 : 200; + }}); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockLowLevelRequest, atLeastOnce()).addHeader(anyString(), + anyString()); + verify(mockLowLevelRequest, times(retries)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(retries)).execute(); + verify(mockLowLevelResponse, times(retries)).getStatusCode(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializableUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializableUtilsTest.java new file mode 100644 index 000000000000..90f10cdc9713 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializableUtilsTest.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for SerializableUtils. + */ +@RunWith(JUnit4.class) +public class SerializableUtilsTest { + static class TestClass implements Serializable { + final String stringValue; + final int intValue; + + public TestClass(String stringValue, int intValue) { + this.stringValue = stringValue; + this.intValue = intValue; + } + } + + @Test + public void testTranscode() { + String stringValue = "hi bob"; + int intValue = 42; + + TestClass testObject = new TestClass(stringValue, intValue); + + Object copy = + SerializableUtils.deserializeFromByteArray( + SerializableUtils.serializeToByteArray(testObject), + "a TestObject"); + + Assert.assertThat(copy, new IsInstanceOf(TestClass.class)); + TestClass testCopy = (TestClass) copy; + + Assert.assertEquals(stringValue, testCopy.stringValue); + Assert.assertEquals(intValue, testCopy.intValue); + } + + @Test + public void testDeserializationError() { + try { + SerializableUtils.deserializeFromByteArray( + "this isn't legal".getBytes(), + "a bogus string"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + CoreMatchers.containsString( + "unable to deserialize a bogus string")); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializerTest.java new file mode 100644 index 000000000000..40e2cc00f650 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializerTest.java @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDouble; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests Serializer implementation. + */ +@RunWith(JUnit4.class) +@Ignore +public class SerializerTest { + /** + * A POJO to use for testing serialization. + */ + @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, + property = PropertyNames.OBJECT_TYPE_NAME) + public static class TestRecord { + // TODO: When we apply property name typing to all non-final classes, the + // annotation on this class should be removed. + public String name; + public boolean ok; + public int value; + public double dValue; + } + + @Test + public void testStatefulDeserialization() { + CloudObject object = CloudObject.forClass(TestRecord.class); + + addString(object, "name", "foobar"); + addBoolean(object, "ok", true); + addLong(object, "value", 42L); + addDouble(object, "dValue", .25); + + TestRecord record = Serializer.deserialize(object, TestRecord.class); + Assert.assertEquals("foobar", record.name); + Assert.assertEquals(true, record.ok); + Assert.assertEquals(42L, record.value); + Assert.assertEquals(0.25, record.dValue, 0.0001); + } + + private static class InjectedTestRecord { + private final String n; + private final int v; + + public InjectedTestRecord( + @JsonProperty("name") String name, + @JsonProperty("value") int value) { + this.n = name; + this.v = value; + } + + public String getName() { + return n; + } + public int getValue() { + return v; + } + } + + @Test + public void testDeserializationInjection() { + CloudObject object = CloudObject.forClass(InjectedTestRecord.class); + addString(object, "name", "foobar"); + addLong(object, "value", 42L); + + InjectedTestRecord record = + Serializer.deserialize(object, InjectedTestRecord.class); + + Assert.assertEquals("foobar", record.getName()); + Assert.assertEquals(42L, record.getValue()); + } + + private static class FactoryInjectedTestRecord { + @JsonCreator + public static FactoryInjectedTestRecord of( + @JsonProperty("name") String name, + @JsonProperty("value") int value) { + return new FactoryInjectedTestRecord(name, value); + } + + private final String n; + private final int v; + + private FactoryInjectedTestRecord(String name, int value) { + this.n = name; + this.v = value; + } + + public String getName() { + return n; + } + public int getValue() { + return v; + } + } + + @Test + public void testDeserializationFactoryInjection() { + CloudObject object = CloudObject.forClass(FactoryInjectedTestRecord.class); + addString(object, "name", "foobar"); + addLong(object, "value", 42L); + + FactoryInjectedTestRecord record = + Serializer.deserialize(object, FactoryInjectedTestRecord.class); + Assert.assertEquals("foobar", record.getName()); + Assert.assertEquals(42L, record.getValue()); + } + + private static class DerivedTestRecord extends TestRecord { + public String derived; + } + + @Test + public void testSubclassDeserialization() { + CloudObject object = CloudObject.forClass(DerivedTestRecord.class); + + addString(object, "name", "foobar"); + addBoolean(object, "ok", true); + addLong(object, "value", 42L); + addDouble(object, "dValue", .25); + addString(object, "derived", "baz"); + + TestRecord result = Serializer.deserialize(object, TestRecord.class); + Assert.assertThat(result, Matchers.instanceOf(DerivedTestRecord.class)); + + DerivedTestRecord record = (DerivedTestRecord) result; + Assert.assertEquals("foobar", record.name); + Assert.assertEquals(true, record.ok); + Assert.assertEquals(42L, record.value); + Assert.assertEquals(0.25, record.dValue, 0.0001); + Assert.assertEquals("baz", record.derived); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StreamingGroupAlsoByWindowsDoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StreamingGroupAlsoByWindowsDoFnTest.java new file mode 100644 index 000000000000..94c44c707d0f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StreamingGroupAlsoByWindowsDoFnTest.java @@ -0,0 +1,282 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.WindowUtils.windowToString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowingFn; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link StreamingGroupAlsoByWindowsDoFn}. */ +@RunWith(JUnit4.class) +public class StreamingGroupAlsoByWindowsDoFnTest { + ExecutionContext execContext; + CounterSet counters; + TupleTag>> outputTag; + + @Before public void setUp() { + execContext = new DirectModeExecutionContext(); + counters = new CounterSet(); + outputTag = new TupleTag<>(); + } + + @Test public void testEmpty() throws Exception { + DoFnRunner>, + KV>, List> runner = + makeRunner(FixedWindows.of(Duration.millis(10))); + + runner.startBundle(); + + runner.finishBundle(); + + List>> result = runner.getReceiver(outputTag); + + assertEquals(0, result.size()); + } + + @Test public void testFixedWindows() throws Exception { + DoFnRunner>, + KV>, List> runner = + makeRunner(FixedWindows.of(Duration.millis(10))); + + Coder windowCoder = FixedWindows.of(Duration.millis(10)).windowCoder(); + + runner.startBundle(); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v1")), + new Instant(1), + Arrays.asList(window(0, 10)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v2")), + new Instant(2), + Arrays.asList(window(0, 10)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v0")), + new Instant(0), + Arrays.asList(window(0, 10)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v3")), + new Instant(13), + Arrays.asList(window(10, 20)))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(0, 10), windowCoder), + new Instant(9), "k"))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(10, 20), windowCoder), + new Instant(19), "k"))); + + runner.finishBundle(); + + List>>> result = runner.getReceiver(outputTag); + + assertEquals(2, result.size()); + + WindowedValue>> item0 = result.get(0); + assertEquals("k", item0.getValue().getKey()); + assertThat(item0.getValue().getValue(), Matchers.containsInAnyOrder("v0", "v1", "v2")); + assertEquals(new Instant(9), item0.getTimestamp()); + assertThat(item0.getWindows(), Matchers.contains(window(0, 10))); + + WindowedValue>> item1 = result.get(1); + assertEquals("k", item1.getValue().getKey()); + assertThat(item1.getValue().getValue(), Matchers.containsInAnyOrder("v3")); + assertEquals(new Instant(19), item1.getTimestamp()); + assertThat(item1.getWindows(), Matchers.contains(window(10, 20))); + } + + @Test public void testSlidingWindows() throws Exception { + DoFnRunner>, + KV>, List> runner = + makeRunner(SlidingWindows.of(Duration.millis(20)).every(Duration.millis(10))); + + Coder windowCoder = + SlidingWindows.of(Duration.millis(10)).every(Duration.millis(10)).windowCoder(); + + runner.startBundle(); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v1")), + new Instant(5), + Arrays.asList(window(-10, 10), window(0, 20)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v0")), + new Instant(2), + Arrays.asList(window(-10, 10), window(0, 20)))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(-10, 10), windowCoder), + new Instant(9), "k"))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v2")), + new Instant(5), + Arrays.asList(window(0, 20), window(10, 30)))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(0, 20), windowCoder), + new Instant(19), "k"))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(10, 30), windowCoder), + new Instant(29), "k"))); + + runner.finishBundle(); + + List>>> result = runner.getReceiver(outputTag); + + assertEquals(3, result.size()); + + WindowedValue>> item0 = result.get(0); + assertEquals("k", item0.getValue().getKey()); + assertThat(item0.getValue().getValue(), Matchers.containsInAnyOrder("v0", "v1")); + assertEquals(new Instant(9), item0.getTimestamp()); + assertThat(item0.getWindows(), Matchers.contains(window(-10, 10))); + + WindowedValue>> item1 = result.get(1); + assertEquals("k", item1.getValue().getKey()); + assertThat(item1.getValue().getValue(), Matchers.containsInAnyOrder("v0", "v1", "v2")); + assertEquals(new Instant(19), item1.getTimestamp()); + assertThat(item1.getWindows(), Matchers.contains(window(0, 20))); + + WindowedValue>> item2 = result.get(2); + assertEquals("k", item2.getValue().getKey()); + assertThat(item2.getValue().getValue(), Matchers.containsInAnyOrder("v2")); + assertEquals(new Instant(29), item2.getTimestamp()); + assertThat(item2.getWindows(), Matchers.contains(window(10, 30))); + } + + @Test public void testSessions() throws Exception { + DoFnRunner>, + KV>, List> runner = + makeRunner(Sessions.withGapDuration(Duration.millis(10))); + + Coder windowCoder = + Sessions.withGapDuration(Duration.millis(10)).windowCoder(); + + runner.startBundle(); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v1")), + new Instant(0), + Arrays.asList(window(0, 10)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v2")), + new Instant(5), + Arrays.asList(window(5, 15)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v3")), + new Instant(15), + Arrays.asList(window(15, 25)))); + + runner.processElement(WindowedValue.of( + TimerOrElement.element(KV.of("k", "v0")), + new Instant(3), + Arrays.asList(window(3, 13)))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(0, 15), windowCoder), + new Instant(14), "k"))); + + runner.processElement(WindowedValue.valueInEmptyWindows( + TimerOrElement.>timer( + windowToString((IntervalWindow) window(15, 25), windowCoder), + new Instant(24), "k"))); + + runner.finishBundle(); + + List>>> result = runner.getReceiver(outputTag); + + assertEquals(2, result.size()); + + WindowedValue>> item0 = result.get(0); + assertEquals("k", item0.getValue().getKey()); + assertThat(item0.getValue().getValue(), Matchers.containsInAnyOrder("v0", "v1", "v2")); + assertEquals(new Instant(14), item0.getTimestamp()); + assertThat(item0.getWindows(), Matchers.contains(window(0, 15))); + + WindowedValue>> item1 = result.get(1); + assertEquals("k", item1.getValue().getKey()); + assertThat(item1.getValue().getValue(), Matchers.containsInAnyOrder("v3")); + assertEquals(new Instant(24), item1.getTimestamp()); + assertThat(item1.getWindows(), Matchers.contains(window(15, 25))); + } + + + private DoFnRunner>, + KV>, List> makeRunner( + WindowingFn windowingStrategy) { + + StreamingGroupAlsoByWindowsDoFn, IntervalWindow> fn = + StreamingGroupAlsoByWindowsDoFn.create(windowingStrategy, StringUtf8Coder.of()); + + DoFnRunner>, + KV>, List> runner = + DoFnRunner.createWithListOutputs( + PipelineOptionsFactory.create(), + fn, + PTuple.empty(), + outputTag, + new ArrayList>(), + execContext.createStepContext("merge"), + counters.getAddCounterMutator()); + + return runner; + } + + private BoundedWindow window(long start, long end) { + return new IntervalWindow(new Instant(start), new Instant(end)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StringUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StringUtilsTest.java new file mode 100644 index 000000000000..bf1a3193b7e3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StringUtilsTest.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Tests for StringUtils. + */ +@RunWith(JUnit4.class) +public class StringUtilsTest { + @Test + public void testTranscodeEmptyByteArray() { + byte[] bytes = { }; + String string = ""; + assertEquals(string, StringUtils.byteArrayToJsonString(bytes)); + assertArrayEquals(bytes, StringUtils.jsonStringToByteArray(string)); + } + + @Test + public void testTranscodeMixedByteArray() { + byte[] bytes = { + 0, 5, 12, 16, 31, 32, 65, 66, 126, 127, (byte) 128, (byte) 255, 67, 0 }; + String string = "%00%05%0c%10%1f AB~%7f%80%ffC%00"; + assertEquals(string, StringUtils.byteArrayToJsonString(bytes)); + assertArrayEquals(bytes, StringUtils.jsonStringToByteArray(string)); + } + + /** + * Inner class for simple name test. + */ + private class EmbeddedDoFn { + // Returns an anonymous inner class. + private EmbeddedDoFn getEmbedded() { + return new EmbeddedDoFn(){}; + } + } + + @Test + public void testSimpleName() { + assertEquals("Embedded", + StringUtils.approximateSimpleName(EmbeddedDoFn.class)); + } + + @Test + public void testAnonSimpleName() { + EmbeddedDoFn anon = new EmbeddedDoFn(){}; + + Pattern p = Pattern.compile("StringUtilsTest\\$[0-9]+"); + Matcher m = p.matcher(StringUtils.approximateSimpleName(anon.getClass())); + assertThat(m.matches(), is(true)); + } + + @Test + public void testNestedSimpleName() { + EmbeddedDoFn fn = new EmbeddedDoFn(); + EmbeddedDoFn anon = fn.getEmbedded(); + + // Expect to find "Embedded$1" + Pattern p = Pattern.compile("Embedded\\$[0-9]+"); + Matcher m = p.matcher(StringUtils.approximateSimpleName(anon.getClass())); + assertThat(m.matches(), is(true)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StructsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StructsTest.java new file mode 100644 index 000000000000..9b8cc208fca9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StructsTest.java @@ -0,0 +1,177 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDouble; +import static com.google.cloud.dataflow.sdk.util.Structs.addList; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addLongs; +import static com.google.cloud.dataflow.sdk.util.Structs.addNull; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.addStringList; +import static com.google.cloud.dataflow.sdk.util.Structs.getBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.getDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.getLong; +import static com.google.cloud.dataflow.sdk.util.Structs.getObject; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; +import static com.google.cloud.dataflow.sdk.util.Structs.getStrings; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Tests for Structs. + */ +@RunWith(JUnit4.class) +public class StructsTest { + private List> makeCloudObjects() { + List> objects = new ArrayList<>(); + { + CloudObject o = CloudObject.forClassName("string"); + addString(o, "singletonStringKey", "stringValue"); + objects.add(o); + } + { + CloudObject o = CloudObject.forClassName("long"); + addLong(o, "singletonLongKey", 42L); + objects.add(o); + } + return objects; + } + + private Map makeCloudDictionary() { + Map o = new HashMap<>(); + addList(o, "emptyKey", Collections.>emptyList()); + addNull(o, "noStringsKey"); + addString(o, "singletonStringKey", "stringValue"); + addStringList(o, "multipleStringsKey", Arrays.asList("hi", "there", "bob")); + addLongs(o, "multipleLongsKey", 47L, 1L << 42, -5L); + addLong(o, "singletonLongKey", 42L); + addDouble(o, "singletonDoubleKey", 3.14); + addBoolean(o, "singletonBooleanKey", true); + addNull(o, "noObjectsKey"); + addList(o, "multipleObjectsKey", makeCloudObjects()); + return o; + } + + @Test + public void testGetStringParameter() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + "stringValue", + getString(o, "singletonStringKey")); + Assert.assertEquals( + "stringValue", + getString(o, "singletonStringKey", "defaultValue")); + Assert.assertEquals( + "defaultValue", + getString(o, "missingKey", "defaultValue")); + + try { + getString(o, "missingKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString( + "didn't find required parameter missingKey")); + } + + try { + getString(o, "noStringsKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a string")); + } + + Assert.assertThat(getStrings(o, "noStringsKey", null), Matchers.emptyIterable()); + Assert.assertThat(getObject(o, "noStringsKey").keySet(), Matchers.emptyIterable()); + Assert.assertThat(getDictionary(o, "noStringsKey").keySet(), Matchers.emptyIterable()); + Assert.assertThat(getDictionary(o, "noStringsKey", null).keySet(), + Matchers.emptyIterable()); + + try { + getString(o, "multipleStringsKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a string")); + } + + try { + getString(o, "emptyKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a string")); + } + } + + @Test + public void testGetBooleanParameter() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + true, + getBoolean(o, "singletonBooleanKey", false)); + Assert.assertEquals( + false, + getBoolean(o, "missingKey", false)); + + try { + getBoolean(o, "emptyKey", false); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a boolean")); + } + } + + @Test + public void testGetLongParameter() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + (Long) 42L, + getLong(o, "singletonLongKey", 666L)); + Assert.assertEquals( + (Long) 666L, + getLong(o, "missingKey", 666L)); + + try { + getLong(o, "emptyKey", 666L); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not an int")); + } + } + + // TODO: Test builder operations. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimeUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimeUtilTest.java new file mode 100644 index 000000000000..1faebeba7c0e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimeUtilTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudTime; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TimeUtil}. */ +@RunWith(JUnit4.class) +public final class TimeUtilTest { + @Test + public void toCloudTimeShouldPrintTimeStrings() { + assertEquals("1970-01-01T00:00:00Z", toCloudTime(new Instant(0))); + assertEquals("1970-01-01T00:00:00.001Z", toCloudTime(new Instant(1))); + } + + @Test + public void fromCloudTimeShouldParseTimeStrings() { + assertEquals(new Instant(0), fromCloudTime("1970-01-01T00:00:00Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001001Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000000Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000001Z")); + assertNull(fromCloudTime("")); + assertNull(fromCloudTime("1970-01-01T00:00:00")); + } + + @Test + public void toCloudDurationShouldPrintDurationStrings() { + assertEquals("0s", toCloudDuration(Duration.ZERO)); + assertEquals("4s", toCloudDuration(Duration.millis(4000))); + assertEquals("4.001s", toCloudDuration(Duration.millis(4001))); + } + + @Test + public void fromCloudDurationShouldParseDurationStrings() { + assertEquals(Duration.millis(4000), fromCloudDuration("4s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001000s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001001s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001000000s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001000001s")); + assertNull(fromCloudDuration("")); + assertNull(fromCloudDuration("4")); + assertNull(fromCloudDuration("4.1")); + assertNull(fromCloudDuration("4.1s")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/VarIntTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/VarIntTest.java new file mode 100644 index 000000000000..d6b771bd0512 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/VarIntTest.java @@ -0,0 +1,281 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.EOFException; +import java.io.IOException; + +/** Unit tests for {@link VarInt}. */ +@RunWith(JUnit4.class) +public class VarIntTest { + @Rule public final ExpectedException thrown = ExpectedException.none(); + + // Long values to check for boundary cases. + private static final long[] LONG_VALUES = { + 0, + 1, + 127, + 128, + 16383, + 16384, + 2097151, + 2097152, + 268435455, + 268435456, + 34359738367L, + 34359738368L, + 9223372036854775807L, + -9223372036854775808L, + -1, + }; + + // VarInt encoding of the above VALUES. + private static final byte[][] LONG_ENCODED = { + // 0 + { 0x00 }, + // 1 + { 0x01 }, + // 127 + { 0x7f }, + // 128 + { (byte) 0x80, 0x01 }, + // 16383 + { (byte) 0xff, 0x7f }, + // 16834 + { (byte) 0x80, (byte) 0x80, 0x01 }, + // 2097151 + { (byte) 0xff, (byte) 0xff, 0x7f }, + // 2097152 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x01 }, + // 268435455 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x7f }, + // 268435456 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x01 }, + // 34359738367 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x7f }, + // 34359738368 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, + 0x01 }, + // 9223372036854775807 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, + (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x7f }, + // -9223372036854775808L + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, + (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x01 }, + // -1 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, + (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x01 } + }; + + // Integer values to check for boundary cases. + private static final int[] INT_VALUES = { + 0, + 1, + 127, + 128, + 16383, + 16384, + 2097151, + 2097152, + 268435455, + 268435456, + 2147483647, + -2147483648, + -1, + }; + + // VarInt encoding of the above VALUES. + private static final byte[][] INT_ENCODED = { + // 0 + { (byte) 0x00 }, + // 1 + { (byte) 0x01 }, + // 127 + { (byte) 0x7f }, + // 128 + { (byte) 0x80, (byte) 0x01 }, + // 16383 + { (byte) 0xff, (byte) 0x7f }, + // 16834 + { (byte) 0x80, (byte) 0x80, (byte) 0x01 }, + // 2097151 + { (byte) 0xff, (byte) 0xff, (byte) 0x7f }, + // 2097152 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x01 }, + // 268435455 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x7f }, + // 268435456 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x01 }, + // 2147483647 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x07 }, + // -2147483648 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x08 }, + // -1 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x0f } + }; + + private static byte[] encodeInt(int v) throws IOException { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + VarInt.encode(v, stream); + return stream.toByteArray(); + } + + private static byte[] encodeLong(long v) throws IOException { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + VarInt.encode(v, stream); + return stream.toByteArray(); + } + + private static int decodeInt(byte[] encoded) throws IOException { + ByteArrayInputStream stream = new ByteArrayInputStream(encoded); + return VarInt.decodeInt(stream); + } + + private static long decodeLong(byte[] encoded) throws IOException { + ByteArrayInputStream stream = new ByteArrayInputStream(encoded); + return VarInt.decodeLong(stream); + } + + @Test + public void decodeValues() throws IOException { + assertEquals(LONG_VALUES.length, LONG_ENCODED.length); + for (int i = 0; i < LONG_ENCODED.length; ++i) { + ByteArrayInputStream stream = new ByteArrayInputStream(LONG_ENCODED[i]); + long parsed = VarInt.decodeLong(stream); + assertEquals(LONG_VALUES[i], parsed); + assertEquals(-1, stream.read()); + } + + assertEquals(INT_VALUES.length, INT_ENCODED.length); + for (int i = 0; i < INT_ENCODED.length; ++i) { + ByteArrayInputStream stream = new ByteArrayInputStream(INT_ENCODED[i]); + int parsed = VarInt.decodeInt(stream); + assertEquals(INT_VALUES[i], parsed); + assertEquals(-1, stream.read()); + } + } + + @Test + public void encodeValuesAndGetLength() throws IOException { + assertEquals(LONG_VALUES.length, LONG_ENCODED.length); + for (int i = 0; i < LONG_VALUES.length; ++i) { + byte[] encoded = encodeLong(LONG_VALUES[i]); + assertThat(encoded, equalTo(LONG_ENCODED[i])); + assertEquals(LONG_ENCODED[i].length, VarInt.getLength(LONG_VALUES[i])); + } + + assertEquals(INT_VALUES.length, INT_ENCODED.length); + for (int i = 0; i < INT_VALUES.length; ++i) { + byte[] encoded = encodeInt(INT_VALUES[i]); + assertThat(encoded, equalTo(INT_ENCODED[i])); + assertEquals(INT_ENCODED[i].length, VarInt.getLength(INT_VALUES[i])); + } + } + + @Test + public void decodeThrowsExceptionForOverflow() throws IOException { + final byte[] tooLargeNumber = + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, + (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x02 }; + + thrown.expect(IOException.class); + + long parsed = decodeLong(tooLargeNumber); + } + + @Test + public void decodeThrowsExceptionForIntOverflow() throws IOException { + byte[] encoded = encodeLong(1L << 32); + + thrown.expect(IOException.class); + + int parsed = decodeInt(encoded); + } + + @Test + public void decodeThrowsExceptionForIntUnderflow() throws IOException { + byte[] encoded = encodeLong(-1); + + thrown.expect(IOException.class); + + int parsed = decodeInt(encoded); + } + + @Test + public void decodeThrowsExceptionForNonterminated() throws IOException { + final byte[] nonTerminatedNumber = + { (byte) 0xff, (byte) 0xff }; + + thrown.expect(IOException.class); + + long parsed = decodeLong(nonTerminatedNumber); + } + + @Test + public void decodeParsesEncodedValues() throws IOException { + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + for (int i = 10; i < Integer.MAX_VALUE; i = (int) (i * 1.1)) { + VarInt.encode(i, outStream); + VarInt.encode(-i, outStream); + } + for (long i = 10; i < Long.MAX_VALUE; i = (long) (i * 1.1)) { + VarInt.encode(i, outStream); + VarInt.encode(-i, outStream); + } + + ByteArrayInputStream inStream = + new ByteArrayInputStream(outStream.toByteArray()); + for (int i = 10; i < Integer.MAX_VALUE; i = (int) (i * 1.1)) { + assertEquals(i, VarInt.decodeInt(inStream)); + assertEquals(-i, VarInt.decodeInt(inStream)); + } + for (long i = 10; i < Long.MAX_VALUE; i = (long) (i * 1.1)) { + assertEquals(i, VarInt.decodeLong(inStream)); + assertEquals(-i, VarInt.decodeLong(inStream)); + } + } + + @Test + public void endOfFileThrowsException() throws Exception { + ByteArrayInputStream inStream = + new ByteArrayInputStream(new byte[0]); + thrown.expect(EOFException.class); + VarInt.decodeInt(inStream); + } + + @Test + public void unterminatedThrowsException() throws Exception { + byte[] e = encodeLong(Long.MAX_VALUE); + byte[] s = new byte[1]; + s[0] = e[0]; + ByteArrayInputStream inStream = new ByteArrayInputStream(s); + thrown.expect(IOException.class); + VarInt.decodeInt(inStream); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/WindowedValueTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/WindowedValueTest.java new file mode 100644 index 000000000000..67f21f549092 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/WindowedValueTest.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; + +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** Test case for {@link WindowedValue}. */ +@RunWith(JUnit4.class) +public class WindowedValueTest { + @Test + public void testWindowedValueCoder() throws CoderException { + Instant timestamp = new Instant(1234); + WindowedValue value = WindowedValue.of( + "abc", + new Instant(1234), + Arrays.asList(new IntervalWindow(timestamp, timestamp.plus(1000)), + new IntervalWindow(timestamp.plus(1000), timestamp.plus(2000)))); + + Coder> windowedValueCoder = + WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder()); + + byte[] encodedValue = CoderUtils.encodeToByteArray(windowedValueCoder, value); + WindowedValue decodedValue = + CoderUtils.decodeFromByteArray(windowedValueCoder, encodedValue); + + Assert.assertEquals(value.getValue(), decodedValue.getValue()); + Assert.assertEquals(value.getTimestamp(), decodedValue.getTimestamp()); + Assert.assertArrayEquals(value.getWindows().toArray(), decodedValue.getWindows().toArray()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterSetTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterSetTest.java new file mode 100644 index 000000000000..c8fa6c2fab5a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterSetTest.java @@ -0,0 +1,75 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SET; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link CounterSet}. + */ +@RunWith(JUnit4.class) +public class CounterSetTest { + @Test + public void testSet() { + CounterSet set = new CounterSet(); + assertTrue(set.add(Counter.longs("c1", SUM))); + assertFalse(set.add(Counter.longs("c1", SUM))); + assertTrue(set.add(Counter.longs("c2", MAX))); + assertEquals(2, set.size()); + } + + @Test + public void testAddCounterMutator() { + CounterSet set = new CounterSet(); + Counter c1 = Counter.longs("c1", SUM); + Counter c1SecondInstance = Counter.longs("c1", SUM); + Counter c1IncompatibleInstance = Counter.longs("c1", SET); + Counter c2 = Counter.longs("c2", MAX); + Counter c2IncompatibleInstance = Counter.doubles("c2", MAX); + + assertEquals(c1, set.getAddCounterMutator().addCounter(c1)); + assertEquals(c2, set.getAddCounterMutator().addCounter(c2)); + + assertEquals(c1, set.getAddCounterMutator().addCounter(c1SecondInstance)); + + try { + set.getAddCounterMutator().addCounter(c1IncompatibleInstance); + fail("should have failed"); + } catch (IllegalArgumentException exn) { + // Expected. + } + + try { + set.getAddCounterMutator().addCounter(c2IncompatibleInstance); + fail("should have failed"); + } catch (IllegalArgumentException exn) { + // Expected. + } + + assertEquals(2, set.size()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTest.java new file mode 100644 index 000000000000..ff40e0d06f18 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTest.java @@ -0,0 +1,743 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.Values.asBoolean; +import static com.google.cloud.dataflow.sdk.util.Values.asDouble; +import static com.google.cloud.dataflow.sdk.util.Values.asLong; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.AND; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MIN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.OR; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SET; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.util.CloudCounterUtils; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.common.collect.Sets; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Unit tests for the {@link Counter} API. + */ +@RunWith(JUnit4.class) +public class CounterTest { + + private static MetricUpdate flush(Counter c) { + // TODO: Move this out into a separate Counter test. + return CounterTestUtils.extractCounterUpdate(c, true); + } + + private static final double EPSILON = 0.00000000001; + + @Test + public void testNameKindAndCloudCounterRepresentation() { + Counter c1 = Counter.longs("c1", SUM); + Counter c2 = Counter.doubles("c2", MAX); + Counter c3 = Counter.strings("c3", SET); + Counter c4 = Counter.doubles("c4", MEAN); + Counter c5 = Counter.ints("c5", MIN); + Counter c6 = Counter.booleans("c6", AND); + Counter c7 = Counter.booleans("c7", OR); + + assertEquals("c1", c1.getName()); + assertEquals(SUM, c1.getKind()); + MetricUpdate cc = flush(c1); + assertEquals("c1", cc.getName().getName()); + assertEquals("SUM", cc.getKind()); + assertEquals(0L, asLong(cc.getScalar()).longValue()); + c1.addValue(123L).addValue(-13L); + cc = flush(c1); + assertEquals(110L, asLong(cc.getScalar()).longValue()); + + assertEquals("c2", c2.getName()); + assertEquals(MAX, c2.getKind()); + cc = flush(c2); + assertEquals("c2", cc.getName().getName()); + assertEquals("MAX", cc.getKind()); + assertEquals(Double.MIN_VALUE, asDouble(cc.getScalar()), EPSILON); + c2.resetToValue(0.0).addValue(Math.PI).addValue(Math.E); + cc = flush(c2); + assertEquals(Math.PI, asDouble(cc.getScalar()), EPSILON); + + assertEquals("c3", c3.getName()); + assertEquals(SET, c3.getKind()); + cc = flush(c3); // empty sets are not sent to the service + assertEquals(null, cc); + c3.addValue("abc").addValue("e").addValue("abc"); + cc = flush(c3); + assertEquals("c3", cc.getName().getName()); + assertEquals("SET", cc.getKind()); + Set s = (Set) cc.getSet(); + assertEquals(2, s.size()); + assertTrue(s.containsAll(Arrays.asList( + CloudObject.forString("e"), + CloudObject.forString("abc")))); + + assertEquals("c4", c4.getName()); + assertEquals(MEAN, c4.getKind()); + cc = flush(c4); // zero-count means are not sent to the service + assertEquals(null, cc); + c4.addValue(Math.PI).addValue(Math.E).addValue(Math.sqrt(2)); + cc = flush(c4); + assertEquals("c4", cc.getName().getName()); + assertEquals("MEAN", cc.getKind()); + Object ms = cc.getMeanSum(); + Object mc = cc.getMeanCount(); + assertEquals(Math.PI + Math.E + Math.sqrt(2), asDouble(ms), EPSILON); + assertEquals(3, asLong(mc).longValue()); + c4.addValue(2.0).addValue(5.0); + cc = flush(c4); + ms = cc.getMeanSum(); + mc = cc.getMeanCount(); + assertEquals(7.0, asDouble(ms), EPSILON); + assertEquals(2L, asLong(mc).longValue()); + + assertEquals("c5", c5.getName()); + assertEquals(MIN, c5.getKind()); + cc = flush(c5); + assertEquals("c5", cc.getName().getName()); + assertEquals("MIN", cc.getKind()); + assertEquals(Integer.MAX_VALUE, asLong(cc.getScalar()).longValue()); + c5.addValue(123).addValue(-13); + cc = flush(c5); + assertEquals(-13, asLong(cc.getScalar()).longValue()); + + assertEquals("c6", c6.getName()); + assertEquals(AND, c6.getKind()); + cc = flush(c6); + assertEquals("c6", cc.getName().getName()); + assertEquals("AND", cc.getKind()); + assertEquals(true, asBoolean(cc.getScalar())); + c6.addValue(false); + cc = flush(c6); + assertEquals(false, asBoolean(cc.getScalar())); + + assertEquals("c7", c7.getName()); + assertEquals(OR, c7.getKind()); + cc = flush(c7); + assertEquals("c7", cc.getName().getName()); + assertEquals("OR", cc.getKind()); + assertEquals(false, asBoolean(cc.getScalar())); + c7.addValue(true); + cc = flush(c7); + assertEquals(true, asBoolean(cc.getScalar())); + } + + @Test + public void testCompatibility() { + // Equal counters are compatible, of all kinds. + assertTrue( + Counter.longs("c", SUM).isCompatibleWith(Counter.longs("c", SUM))); + assertTrue( + Counter.ints("c", SUM).isCompatibleWith(Counter.ints("c", SUM))); + assertTrue( + Counter.doubles("c", SUM).isCompatibleWith(Counter.doubles("c", SUM))); + assertTrue( + Counter.strings("c", SET).isCompatibleWith(Counter.strings("c", SET))); + assertTrue( + Counter.booleans("c", OR).isCompatibleWith( + Counter.booleans("c", OR))); + + // The name, kind, and type of the counter must match. + assertFalse( + Counter.longs("c", SUM).isCompatibleWith(Counter.longs("c2", SUM))); + assertFalse( + Counter.longs("c", SUM).isCompatibleWith(Counter.longs("c", MAX))); + assertFalse( + Counter.longs("c", SUM).isCompatibleWith(Counter.ints("c", SUM))); + + // The value of the counters are ignored. + assertTrue( + Counter.longs("c", SUM).resetToValue(666L).isCompatibleWith( + Counter.longs("c", SUM).resetToValue(42L))); + } + + + private void assertOK(long total, long delta, Counter c) { + assertEquals(total, c.getTotalAggregate().longValue()); + assertEquals(delta, c.getDeltaAggregate().longValue()); + } + + private void assertOK(double total, double delta, Counter c) { + assertEquals(total, asDouble(c.getTotalAggregate()), EPSILON); + assertEquals(delta, asDouble(c.getDeltaAggregate()), EPSILON); + } + + + // Tests for SUM. + + @Test + public void testSumLong() { + Counter c = Counter.longs("sum-long", SUM); + long expectedTotal = 0; + long expectedDelta = 0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(0L); + expectedTotal += 55; + expectedDelta += 55; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = 174; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = 0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(15L).addValue(42L); + expectedTotal += 57; + expectedDelta += 57; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(17L).addValue(49L); + expectedTotal = expectedDelta = 166; + assertOK(expectedTotal, expectedDelta, c); + } + + @Test + public void testSumDouble() { + Counter c = Counter.doubles("sum-double", SUM); + double expectedTotal = 0.0; + double expectedDelta = 0.0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(0.0); + expectedTotal += Math.E + Math.PI; + expectedDelta += Math.E + Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(2)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = expectedDelta = Math.sqrt(2) + 2 * Math.PI + 3 * Math.E; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = 0.0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedTotal += 7 * Math.PI + 5 * Math.E; + expectedDelta += 7 * Math.PI + 5 * Math.E; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(17.0).addValue(49.0); + expectedTotal = expectedDelta = Math.sqrt(17.0) + 17.0 + 49.0; + assertOK(expectedTotal, expectedDelta, c); + } + + + // Tests for MAX. + + @Test + public void testMaxLong() { + Counter c = Counter.longs("max-long", MAX); + long expectedTotal = Long.MIN_VALUE; + long expectedDelta = Long.MIN_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(Long.MIN_VALUE); + expectedTotal = expectedDelta = 42; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = 120; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Long.MIN_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(42L).addValue(15L); + expectedDelta = 42; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(171L).addValue(49L); + expectedTotal = expectedDelta = 171; + assertOK(expectedTotal, expectedDelta, c); + } + + @Test + public void testMaxDouble() { + Counter c = Counter.doubles("max-double", MAX); + double expectedTotal = Double.MIN_VALUE; + double expectedDelta = Double.MIN_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(Double.MIN_VALUE); + expectedTotal = expectedDelta = Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(12345)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = expectedDelta = Math.sqrt(12345); + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Double.MIN_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedDelta = 7 * Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(171.0).addValue(49.0); + expectedTotal = expectedDelta = 171.0; + assertOK(expectedTotal, expectedDelta, c); + } + + + // Tests for MIN. + + @Test + public void testMinLong() { + Counter c = Counter.longs("min-long", MIN); + long expectedTotal = Long.MAX_VALUE; + long expectedDelta = Long.MAX_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(Long.MAX_VALUE); + expectedTotal = expectedDelta = 13; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = 17; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Long.MAX_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(42L).addValue(18L); + expectedDelta = 18; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(171L).addValue(49L); + expectedTotal = expectedDelta = 49; + assertOK(expectedTotal, expectedDelta, c); + } + + @Test + public void testMinDouble() { + Counter c = Counter.doubles("min-double", MIN); + double expectedTotal = Double.MAX_VALUE; + double expectedDelta = Double.MAX_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(Double.MAX_VALUE); + expectedTotal = expectedDelta = Math.E; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(12345)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = expectedDelta = 2 * Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Double.MAX_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedDelta = 5 * Math.E; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(171.0).addValue(0.0); + expectedTotal = expectedDelta = 0.0; + assertOK(expectedTotal, expectedDelta, c); + } + + + // Tests for MEAN. + + private void assertMean(long s, long sd, long c, long cd, Counter cn) { + assertEquals(s, cn.getTotalAggregate().longValue()); + assertEquals(sd, cn.getDeltaAggregate().longValue()); + assertEquals(c, cn.getTotalCount()); + assertEquals(cd, cn.getDeltaCount()); + } + + private void assertMean(double s, double sd, long c, long cd, + Counter cn) { + assertEquals(s, cn.getTotalAggregate().doubleValue(), EPSILON); + assertEquals(sd, cn.getDeltaAggregate().doubleValue(), EPSILON); + assertEquals(c, cn.getTotalCount()); + assertEquals(cd, cn.getDeltaCount()); + } + + @Test + public void testMeanLong() { + Counter c = Counter.longs("mean-long", MEAN); + long expTotal = 0; + long expDelta = 0; + long expCountTotal = 0; + long expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(13L).addValue(42L).addValue(0L); + expTotal += 55; + expDelta += 55; + expCountTotal += 3; + expCountDelta += 3; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetToValue(1L, 120L).addValue(17L).addValue(37L); + expTotal = expDelta = 174; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + flush(c); + expDelta = 0; + expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(15L).addValue(42L); + expTotal += 57; + expDelta += 57; + expCountTotal += 2; + expCountDelta += 2; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetToValue(3L, 100L).addValue(17L).addValue(49L); + expTotal = expDelta = 166; + expCountTotal = expCountDelta = 5; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + } + + @Test + public void testMeanDouble() { + Counter c = Counter.doubles("mean-double", MEAN); + double expTotal = 0.0; + double expDelta = 0.0; + long expCountTotal = 0; + long expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(0.0); + expTotal += Math.E + Math.PI; + expDelta += Math.E + Math.PI; + expCountTotal += 3; + expCountDelta += 3; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetToValue(1L, Math.sqrt(2)).addValue(2 * Math.PI).addValue(3 * Math.E); + expTotal = expDelta = Math.sqrt(2) + 2 * Math.PI + 3 * Math.E; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + flush(c); + expDelta = 0.0; + expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expTotal += 7 * Math.PI + 5 * Math.E; + expDelta += 7 * Math.PI + 5 * Math.E; + expCountTotal += 2; + expCountDelta += 2; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetToValue(3L, Math.sqrt(17)).addValue(17.0).addValue(49.0); + expTotal = expDelta = Math.sqrt(17.0) + 17.0 + 49.0; + expCountTotal = expCountDelta = 5; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + } + + + // Tests for SET. + + private void assertSet(Set total, Set delta, Counter c) { + assertTrue(total.containsAll(c.getTotalSet())); + assertTrue(c.getTotalSet().containsAll(total)); + assertTrue(delta.containsAll(c.getDeltaSet())); + assertTrue(c.getDeltaSet().containsAll(delta)); + } + + @Test + public void testSetLong() { + Counter c = Counter.longs("set-long", SET); + HashSet expectedTotal = new HashSet<>(); + HashSet expectedDelta = new HashSet<>(); + assertSet(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(13L); + expectedTotal = expectedDelta = Sets.newHashSet(13L, 42L); + assertSet(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = Sets.newHashSet(120L, 17L, 37L); + assertSet(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = new HashSet<>(); + assertSet(expectedTotal, expectedDelta, c); + + c.addValue(42L).addValue(18L); + expectedTotal.addAll(Arrays.asList(42L, 18L)); + expectedDelta = Sets.newHashSet(42L, 18L); + assertSet(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(171L).addValue(49L); + expectedTotal = expectedDelta = Sets.newHashSet(100L, 171L, 49L); + assertSet(expectedTotal, expectedDelta, c); + } + + @Test + public void testSetDouble() { + Counter c = Counter.doubles("set-double", SET); + HashSet expectedTotal = new HashSet<>(); + HashSet expectedDelta = new HashSet<>(); + assertSet(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI); + expectedTotal = expectedDelta = Sets.newHashSet(Math.E, Math.PI); + assertSet(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(12345)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = + expectedDelta = Sets.newHashSet(Math.sqrt(12345), 2 * Math.PI, 3 * Math.E); + assertSet(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = new HashSet<>(); + assertSet(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedTotal.addAll(Arrays.asList(7 * Math.PI, 5 * Math.E)); + expectedDelta = Sets.newHashSet(7 * Math.PI, 5 * Math.E); + assertSet(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(171.0).addValue(0.0); + expectedTotal = expectedDelta = Sets.newHashSet(Math.sqrt(17), 171.0, 0.0); + assertSet(expectedTotal, expectedDelta, c); + } + + @Test + public void testSetString() { + Counter c = Counter.strings("set-string", SET); + HashSet expectedTotal = new HashSet<>(); + HashSet expectedDelta = new HashSet<>(); + assertSet(expectedTotal, expectedDelta, c); + + c.addValue("a").addValue("b").addValue("a"); + expectedTotal = expectedDelta = Sets.newHashSet("a", "b"); + assertSet(expectedTotal, expectedDelta, c); + + c.resetToValue("c").addValue("d").addValue("e"); + expectedTotal = expectedDelta = Sets.newHashSet("c", "d", "e"); + assertSet(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = new HashSet<>(); + assertSet(expectedTotal, expectedDelta, c); + + c.addValue("b").addValue("f"); + expectedTotal.addAll(Arrays.asList("b", "f")); + expectedDelta = Sets.newHashSet("b", "f"); + assertSet(expectedTotal, expectedDelta, c); + + c.resetToValue("g").addValue("h").addValue("i"); + expectedTotal = expectedDelta = Sets.newHashSet("g", "h", "i"); + assertSet(expectedTotal, expectedDelta, c); + } + + + // Test for AND and OR. + + private void assertBool(boolean total, boolean delta, Counter c) { + assertEquals(total, c.getTotalAggregate().booleanValue()); + assertEquals(delta, c.getDeltaAggregate().booleanValue()); + } + + @Test + public void testBoolAnd() { + Counter c = Counter.booleans("bool-and", AND); + boolean expectedTotal = true; + boolean expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + expectedTotal = expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.resetToValue(true).addValue(true); + expectedTotal = expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + expectedTotal = expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + assertBool(expectedTotal, expectedDelta, c); + } + + @Test + public void testBoolOr() { + Counter c = Counter.booleans("bool-or", OR); + boolean expectedTotal = false; + boolean expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + expectedTotal = expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.resetToValue(false).addValue(false); + expectedTotal = expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + expectedTotal = expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + assertBool(expectedTotal, expectedDelta, c); + } + + + // Incompatibility tests. + + @Test(expected = IllegalArgumentException.class) + public void testSumBool() { + Counter.booleans("counter", SUM); + } + + @Test(expected = IllegalArgumentException.class) + public void testSumString() { + Counter.strings("counter", SUM); + } + + @Test(expected = IllegalArgumentException.class) + public void testMinBool() { + Counter.booleans("counter", MIN); + } + + @Test(expected = IllegalArgumentException.class) + public void testMinString() { + Counter.strings("counter", MIN); + } + + @Test(expected = IllegalArgumentException.class) + public void testMaxBool() { + Counter.booleans("counter", MAX); + } + + @Test(expected = IllegalArgumentException.class) + public void testMaxString() { + Counter.strings("counter", MAX); + } + + @Test(expected = IllegalArgumentException.class) + public void testMeanBool() { + Counter.booleans("counter", MEAN); + } + + @Test(expected = IllegalArgumentException.class) + public void testMeanString() { + Counter.strings("counter", MEAN); + } + + @Test(expected = IllegalArgumentException.class) + public void testSetBool() { + Counter.booleans("counter", SET); + } + + @Test(expected = IllegalArgumentException.class) + public void testAndLong() { + Counter.longs("counter", AND); + } + + @Test(expected = IllegalArgumentException.class) + public void testAndDouble() { + Counter.doubles("counter", AND); + } + + @Test(expected = IllegalArgumentException.class) + public void testAndString() { + Counter.strings("counter", AND); + } + + @Test(expected = IllegalArgumentException.class) + public void testOrLong() { + Counter.longs("counter", OR); + } + + @Test(expected = IllegalArgumentException.class) + public void testOrDouble() { + Counter.doubles("counter", OR); + } + + @Test(expected = IllegalArgumentException.class) + public void testOrString() { + Counter.strings("counter", OR); + } + + @Test + public void testExtraction() { + Counter[] counters = {Counter.longs("c1", SUM), + Counter.doubles("c2", MAX), + Counter.strings("c3", SET)}; + CounterSet set = new CounterSet(); + for (Counter c : counters) { + set.addCounter(c); + } + + List cloudCountersFromSet = CloudCounterUtils.extractCounters(set, true); + + List cloudCountersFromArray = + CounterTestUtils.extractCounterUpdates(Arrays.asList(counters), true); + + assertEquals(cloudCountersFromArray.size(), cloudCountersFromSet.size()); + for (int i = 0; i < cloudCountersFromArray.size(); i++) { + assertEquals(cloudCountersFromArray.get(i), cloudCountersFromSet.get(i)); + } + + assertEquals(2, cloudCountersFromSet.size()); // empty set was ignored + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTestUtils.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTestUtils.java new file mode 100644 index 000000000000..9c428476e28f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTestUtils.java @@ -0,0 +1,123 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; + +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.CloudCounterUtils; + +import org.junit.Assert; + +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Set; + +/** + * Utilities for testing {@link Counter}s. + */ +public class CounterTestUtils { + + /** + * Extracts a MetricUpdate update from the given counter. This is used mainly + * for testing. + * + * @param extractDelta specifies whether or not to extract the cumulative + * aggregate value or the delta since the last extraction. + */ + public static MetricUpdate extractCounterUpdate(Counter counter, + boolean extractDelta) { + // This may be invoked asynchronously with regular counter updates but + // access to counter data is synchronized, so this is safe. + return CloudCounterUtils.extractCounter(counter, extractDelta); + } + + /** + * Extracts MetricUpdate updates from the given counters. This is used mainly + * for testing. + * + * @param extractDelta specifies whether or not to extract the cumulative + * aggregate values or the deltas since the last extraction. + */ + public static List extractCounterUpdates( + Collection> counters, boolean extractDelta) { + // This may be invoked asynchronously with regular counter updates but + // access to counter data is synchronized, so this is safe. Note however + // that the result is NOT an atomic snapshot across all given counters. + List cloudCounters = new ArrayList<>(counters.size()); + for (Counter counter : counters) { + MetricUpdate cloudCounter = extractCounterUpdate(counter, extractDelta); + if (null != cloudCounter) { + cloudCounters.add(cloudCounter); + } + } + return cloudCounters; + } + + + // These methods expose a counter's values for testing. + + public static T getTotalAggregate(Counter counter) { + return counter.getTotalAggregate(); + } + + public static T getDeltaAggregate(Counter counter) { + return counter.getDeltaAggregate(); + } + + public static long getTotalCount(Counter counter) { + return counter.getTotalCount(); + } + + public static long getDeltaCount(Counter counter) { + return counter.getDeltaCount(); + } + + public static Set getTotalSet(Counter counter) { + return counter.getTotalSet(); + } + + public static Set getDeltaSet(Counter counter) { + return counter.getDeltaSet(); + } + + /** + * A utility method that passes the given (unencoded) elements through + * coder's registerByteSizeObserver() and encode() methods, and confirms + * they are mutually consistent. This is useful for testing coder + * implementations. + */ + public static void testByteCount(Coder coder, Coder.Context context, Object[] elements) + throws Exception { + Counter meanByteCount = Counter.longs("meanByteCount", MEAN); + ElementByteSizeObserver observer = new ElementByteSizeObserver(meanByteCount); + + ByteArrayOutputStream os = new ByteArrayOutputStream(); + for (Object elem : elements) { + coder.registerByteSizeObserver(elem, observer, context); + coder.encode(elem, os, context); + observer.advance(); + } + long expectedLength = os.toByteArray().length; + + Assert.assertEquals(expectedLength, (long) getTotalAggregate(meanByteCount)); + Assert.assertEquals(elements.length, (long) getTotalCount(meanByteCount)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/MetricTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/MetricTest.java new file mode 100644 index 000000000000..0c60901ca0a6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/MetricTest.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.util.common.Metric.DoubleMetric; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link Metric}. */ +@RunWith(JUnit4.class) +public class MetricTest { + @Test + public void testDoubleMetric() { + String name = "metric-name"; + double value = 3.14; + + DoubleMetric doubleMetric = new DoubleMetric(name, value); + + assertEquals(name, doubleMetric.getName()); + assertEquals((Double) value, doubleMetric.getValue()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/BatchingShuffleEntryReaderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/BatchingShuffleEntryReaderTest.java new file mode 100644 index 000000000000..5a4149471728 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/BatchingShuffleEntryReaderTest.java @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.api.client.util.Lists.newArrayList; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.runners.worker.ByteArrayShufflePosition; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Unit tests for {@link BatchingShuffleEntryReader}. */ +@RunWith(JUnit4.class) +public final class BatchingShuffleEntryReaderTest { + private static final byte[] KEY = {0xA}; + private static final byte[] SKEY = {0xB}; + private static final byte[] VALUE = {0xC}; + private static final ShufflePosition START_POSITION = + ByteArrayShufflePosition.of("aaa".getBytes()); + private static final ShufflePosition END_POSITION = + ByteArrayShufflePosition.of("zzz".getBytes()); + private static final ShufflePosition NEXT_START_POSITION = + ByteArrayShufflePosition.of("next".getBytes()); + private static final ShufflePosition SECOND_NEXT_START_POSITION = + ByteArrayShufflePosition.of("next-second".getBytes()); + + @Mock private ShuffleBatchReader batchReader; + private ShuffleEntryReader reader; + + @Before + public void initMocksAndReader() { + MockitoAnnotations.initMocks(this); + reader = new BatchingShuffleEntryReader(batchReader); + } + + @Test + public void readerCanRead() throws Exception { + ShuffleEntry e1 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry e2 = new ShuffleEntry(KEY, SKEY, VALUE); + ArrayList entries = new ArrayList<>(); + entries.add(e1); + entries.add(e2); + when(batchReader.read(START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(entries, null)); + List results = newArrayList(reader.read(START_POSITION, END_POSITION)); + assertThat(results, contains(e1, e2)); + } + + @Test + public void readerIteratorCanBeCopied() throws Exception { + ShuffleEntry e1 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry e2 = new ShuffleEntry(KEY, SKEY, VALUE); + ArrayList entries = new ArrayList<>(); + entries.add(e1); + entries.add(e2); + when(batchReader.read(START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(entries, null)); + Reiterator it = reader.read(START_POSITION, END_POSITION); + assertThat(it.hasNext(), equalTo(Boolean.TRUE)); + assertThat(it.next(), equalTo(e1)); + Reiterator copy = it.copy(); + assertThat(it.hasNext(), equalTo(Boolean.TRUE)); + assertThat(it.next(), equalTo(e2)); + assertThat(it.hasNext(), equalTo(Boolean.FALSE)); + assertThat(copy.hasNext(), equalTo(Boolean.TRUE)); + assertThat(copy.next(), equalTo(e2)); + assertThat(copy.hasNext(), equalTo(Boolean.FALSE)); + } + + @Test + public void readerShouldMergeMultipleBatchResults() throws Exception { + ShuffleEntry e1 = new ShuffleEntry(KEY, SKEY, VALUE); + List e1s = Collections.singletonList(e1); + ShuffleEntry e2 = new ShuffleEntry(KEY, SKEY, VALUE); + List e2s = Collections.singletonList(e2); + when(batchReader.read(START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(e1s, NEXT_START_POSITION)); + when(batchReader.read(NEXT_START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(e2s, null)); + List results = newArrayList(reader.read(START_POSITION, END_POSITION)); + assertThat(results, contains(e1, e2)); + + verify(batchReader).read(START_POSITION, END_POSITION); + verify(batchReader).read(NEXT_START_POSITION, END_POSITION); + verifyNoMoreInteractions(batchReader); + } + + @Test + public void readerShouldMergeMultipleBatchResultsIncludingEmptyShards() + throws Exception { + List e1s = new ArrayList<>(); + List e2s = new ArrayList<>(); + ShuffleEntry e3 = new ShuffleEntry(KEY, SKEY, VALUE); + List e3s = Collections.singletonList(e3); + when(batchReader.read(START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(e1s, NEXT_START_POSITION)); + when(batchReader.read(NEXT_START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(e2s, SECOND_NEXT_START_POSITION)); + when(batchReader.read(SECOND_NEXT_START_POSITION, END_POSITION)) + .thenReturn(new ShuffleBatchReader.Batch(e3s, null)); + List results = newArrayList(reader.read(START_POSITION, END_POSITION)); + assertThat(results, contains(e3)); + + verify(batchReader).read(START_POSITION, END_POSITION); + verify(batchReader).read(NEXT_START_POSITION, END_POSITION); + verify(batchReader).read(SECOND_NEXT_START_POSITION, END_POSITION); + verifyNoMoreInteractions(batchReader); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/CachingShuffleBatchReaderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/CachingShuffleBatchReaderTest.java new file mode 100644 index 000000000000..4175c9150596 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/CachingShuffleBatchReaderTest.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.ArrayList; + +/** Unit tests for {@link CachingShuffleBatchReader}. */ +@RunWith(JUnit4.class) +public final class CachingShuffleBatchReaderTest { + + private final ShuffleBatchReader.Batch testBatch = + new ShuffleBatchReader.Batch(new ArrayList(), null); + + @Test + public void readerShouldCacheReads() throws IOException { + ShuffleBatchReader base = mock(ShuffleBatchReader.class); + CachingShuffleBatchReader reader = new CachingShuffleBatchReader(base); + when(base.read(null, null)).thenReturn(testBatch); + // N.B. We need to capture the result of reader.read() in order to ensure + // that there's a strong reference to it, preventing it from being + // collected. Not that this should be an issue in tests, but it's good to + // be solid. + ShuffleBatchReader.Batch read = reader.read(null, null); + assertThat(read, equalTo(testBatch)); + assertThat(reader.read(null, null), equalTo(testBatch)); + assertThat(reader.read(null, null), equalTo(testBatch)); + assertThat(reader.read(null, null), equalTo(testBatch)); + assertThat(reader.read(null, null), equalTo(testBatch)); + verify(base, times(1)).read(null, null); + } + + @Test + public void readerShouldNotCacheExceptions() throws IOException { + ShuffleBatchReader base = mock(ShuffleBatchReader.class); + CachingShuffleBatchReader reader = new CachingShuffleBatchReader(base); + when(base.read(null, null)) + .thenThrow(new IOException("test")) + .thenReturn(testBatch); + try { + reader.read(null, null); + fail("expected an IOException"); + } catch (IOException e) { + // Nothing to do -- exception is expected. + } + assertThat(reader.read(null, null), equalTo(testBatch)); + verify(base, times(2)).read(null, null); + } + + @Test + public void readerShouldRereadClearedBatches() throws IOException { + ShuffleBatchReader base = mock(ShuffleBatchReader.class); + CachingShuffleBatchReader reader = new CachingShuffleBatchReader(base); + when(base.read(null, null)).thenReturn(testBatch); + ShuffleBatchReader.Batch read = reader.read(null, null); + assertThat(read, equalTo(testBatch)); + verify(base, times(1)).read(null, null); + CachingShuffleBatchReader.BatchRange range = + new CachingShuffleBatchReader.BatchRange(null, null); + CachingShuffleBatchReader.RangeReadReference ref = + reader.cache.get(range); + assertThat(ref, notNullValue()); + ref.clear(); + read = reader.read(null, null); + assertThat(read, equalTo(testBatch)); + verify(base, times(2)).read(null, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ExecutorTestUtils.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ExecutorTestUtils.java new file mode 100644 index 000000000000..0c678abe75d5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ExecutorTestUtils.java @@ -0,0 +1,238 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.worker.MapTaskExecutorFactory.ElementByteSizeObservableCoder; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservable; + +import org.junit.Assert; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Observable; +import java.util.Observer; + +/** + * Utilities for tests. + */ +public class ExecutorTestUtils { + // Do not instantiate. + private ExecutorTestUtils() { } + + /** An Operation with a specified number of outputs. */ + public static class TestOperation extends Operation { + public TestOperation(int numOutputs) { + this(numOutputs, new CounterSet()); + } + + TestOperation(int numOutputs, CounterSet counters) { + this(numOutputs, counters, "test-"); + } + + TestOperation(int numOutputs, CounterSet counters, String counterPrefix) { + this(numOutputs, counterPrefix, counters.getAddCounterMutator(), + new StateSampler(counterPrefix, counters.getAddCounterMutator())); + } + + TestOperation(int numOutputs, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super("TestOperation", + createOutputReceivers(numOutputs, counterPrefix, + addCounterMutator, stateSampler), + counterPrefix, + addCounterMutator, + stateSampler); + } + + private static OutputReceiver[] createOutputReceivers( + int numOutputs, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + OutputReceiver[] receivers = new OutputReceiver[numOutputs]; + for (int i = 0; i < numOutputs; i++) { + receivers[i] = new OutputReceiver( + "out_" + i, + new ElementByteSizeObservableCoder(StringUtf8Coder.of()), + counterPrefix, + addCounterMutator); + } + return receivers; + } + } + + /** An OutputReceiver that allows the output elements to be retrieved. */ + public static class TestReceiver extends OutputReceiver { + List outputElems = new ArrayList<>(); + + public TestReceiver(CounterSet counterSet) { + this("test_receiver_out", counterSet); + } + + public TestReceiver(Coder coder) { + this(coder, new CounterSet()); + } + + public TestReceiver(Coder coder, CounterSet counterSet) { + this("test_receiver_out", + new ElementByteSizeObservableCoder(coder), + counterSet, + "test-"); + } + + public TestReceiver(CounterSet counterSet, String counterPrefix) { + this("test_receiver_out", counterSet, counterPrefix); + } + + public TestReceiver(String outputName, CounterSet counterSet) { + this(outputName, counterSet, "test-"); + } + + public TestReceiver(String outputName, + CounterSet counterSet, String counterPrefix) { + this(outputName, + new ElementByteSizeObservableCoder(StringUtf8Coder.of()), + counterSet, + counterPrefix); + } + + public TestReceiver(ElementByteSizeObservable elementByteSizeObservable, + CounterSet counterSet, String counterPrefix) { + this("test_receiver_out", elementByteSizeObservable, + counterSet, counterPrefix); + } + + public TestReceiver(String outputName, + ElementByteSizeObservable elementByteSizeObservable, + CounterSet counterSet, String counterPrefix) { + super(outputName, + elementByteSizeObservable, + counterPrefix, + counterSet.getAddCounterMutator()); + } + + @Override + public void process(Object elem) throws Exception { + super.process(elem); + outputElems.add(elem); + } + + @Override + protected boolean sampleElement() { + return true; + } + } + + /** A {@code Source} that yields a specified set of values. */ + public static class TestSource extends Source { + List inputs = new ArrayList<>(); + + public void addInput(String... inputs) { + this.inputs.addAll(Arrays.asList(inputs)); + } + + @Override + public SourceIterator iterator() { + return new TestSourceIterator(inputs); + } + + class TestSourceIterator extends AbstractSourceIterator { + Iterator iter; + boolean closed = false; + + public TestSourceIterator(List inputs) { + iter = inputs.iterator(); + } + + @Override + public boolean hasNext() { return iter.hasNext(); } + + @Override + public String next() { + String next = iter.next(); + notifyElementRead(next.length()); + return next; + } + + @Override + public void close() { + Assert.assertFalse(closed); + closed = true; + } + } + } + + /** + * An Observer that stores all sizes into an ArrayList, to compare + * against the gold standard during testing. + */ + public static class TestSourceObserver implements Observer { + private final Source source; + private final List sizes; + + public TestSourceObserver(Source source) { + this(source, new ArrayList()); + } + + public TestSourceObserver(Source source, List sizes) { + this.source = source; + this.sizes = sizes; + source.addObserver(this); + } + + @Override + public void update(Observable obs, Object obj) { + sizes.add((int) (long) obj); + } + + public List getActualSizes() { + return sizes; + } + } + + /** A {@code Sink} that allows the output elements to be retrieved. */ + public static class TestSink extends Sink { + List outputElems = new ArrayList<>(); + boolean closed = false; + + @Override + public SinkWriter writer() { + return new TestSinkWriter(); + } + + class TestSinkWriter implements SinkWriter { + @Override + public long add(String outputElem) { + outputElems.add(outputElem); + return outputElem.length(); + } + + @Override + public void close() { + Assert.assertFalse(closed); + closed = true; + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/FlattenOperationTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/FlattenOperationTest.java new file mode 100644 index 000000000000..d0f8e747de7e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/FlattenOperationTest.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for FlattenOperation. + */ +@RunWith(JUnit4.class) +public class FlattenOperationTest { + @Test + public void testRunFlattenOperation() throws Exception { + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + ExecutorTestUtils.TestReceiver receiver = + new ExecutorTestUtils.TestReceiver(counterSet, counterPrefix); + + FlattenOperation flattenOperation = + new FlattenOperation(receiver, + counterPrefix, counterSet.getAddCounterMutator(), + stateSampler); + + flattenOperation.start(); + + flattenOperation.process("hi"); + flattenOperation.process("there"); + flattenOperation.process(""); + flattenOperation.process("bob"); + + flattenOperation.finish(); + + Assert.assertThat(receiver.outputElems, + CoreMatchers.hasItems("hi", "there", "", "bob")); + + Assert.assertEquals( + new CounterSet( + Counter.longs("test-FlattenOperation-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-FlattenOperation-start-msecs")).getAggregate(false)), + Counter.longs("test-FlattenOperation-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-FlattenOperation-process-msecs")).getAggregate(false)), + Counter.longs("test-FlattenOperation-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-FlattenOperation-finish-msecs")).getAggregate(false)), + Counter.longs("test_receiver_out-ElementCount", SUM) + .resetToValue(4L), + Counter.longs("test_receiver_out-MeanByteCount", MEAN) + .resetToValue(4, 10L)), + counterSet); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/MapTaskExecutorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/MapTaskExecutorTest.java new file mode 100644 index 000000000000..27017962ccc5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/MapTaskExecutorTest.java @@ -0,0 +1,290 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudPositionToSourcePosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.CounterSet.AddCounterMutator; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestReceiver; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestSource; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for MapTaskExecutor. + */ +@RunWith(JUnit4.class) +public class MapTaskExecutorTest { + static class TestOperation extends Operation { + String label; + List log; + + private static CounterSet counterSet = new CounterSet(); + private static String counterPrefix = "test-"; + private static StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + + TestOperation(String label, List log) { + super(label, + new OutputReceiver[]{}, + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + this.label = label; + this.log = log; + } + + TestOperation(String outputName, + String counterPrefix, + CounterSet.AddCounterMutator addCounterMutator, + StateSampler stateSampler, + long outputCount) { + super(outputName, new OutputReceiver[]{}, + counterPrefix, addCounterMutator, stateSampler); + addCounterMutator.addCounter( + Counter.longs(outputName + "-ElementCount", SUM) + .resetToValue(outputCount)); + } + + @Override + public void start() throws Exception { + super.start(); + log.add(label + " started"); + } + + @Override + public void finish() throws Exception { + log.add(label + " finished"); + super.finish(); + } + } + + // A mock ReadOperation fed to a MapTaskExecutor in test. + static class TestReadOperation extends ReadOperation { + private ApproximateProgress progress = null; + + TestReadOperation(OutputReceiver outputReceiver, + String counterPrefix, + AddCounterMutator addCounterMutator, + StateSampler stateSampler) { + super(new TestSource(), outputReceiver, + counterPrefix, addCounterMutator, stateSampler); + } + + @Override + public Source.Progress getProgress() { + return cloudProgressToSourceProgress(progress); + } + + @Override + public Source.Position proposeStopPosition( + Source.Progress proposedStopPosition) { + // Fakes the return with the same position as proposed. + return cloudPositionToSourcePosition( + sourceProgressToCloudProgress(proposedStopPosition) + .getPosition()); + } + + public void setProgress(ApproximateProgress progress) { + this.progress = progress; + } + } + + @Test + public void testExecuteMapTaskExecutor() throws Exception { + List log = new ArrayList<>(); + + List operations = Arrays.asList(new Operation[]{ + new TestOperation("o1", log), + new TestOperation("o2", log), + new TestOperation("o3", log)}); + + CounterSet counters = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counters.getAddCounterMutator()); + MapTaskExecutor executor = + new MapTaskExecutor(operations, counters, stateSampler); + + executor.execute(); + + Assert.assertThat(log, CoreMatchers.hasItems( + "o3 started", + "o2 started", + "o1 started", + "o1 finished", + "o2 finished", + "o3 finished")); + + executor.close(); + } + + @Test + public void testGetOutputCounters() throws Exception { + CounterSet counters = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counters.getAddCounterMutator()); + List operations = Arrays.asList(new Operation[]{ + new TestOperation( + "o1", counterPrefix, counters.getAddCounterMutator(), + stateSampler, 1), + new TestOperation( + "o2", counterPrefix, counters.getAddCounterMutator(), + stateSampler, 2), + new TestOperation( + "o3", counterPrefix, counters.getAddCounterMutator(), + stateSampler, 3)}); + + MapTaskExecutor executor = + new MapTaskExecutor(operations, counters, stateSampler); + + CounterSet counterSet = executor.getOutputCounters(); + Assert.assertEquals( + new CounterSet( + Counter.longs("o1-ElementCount", SUM).resetToValue(1L), + Counter.longs("test-o1-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o1-start-msecs")).getAggregate(false)), + Counter.longs("test-o1-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o1-process-msecs")).getAggregate(false)), + Counter.longs("test-o1-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o1-finish-msecs")).getAggregate(false)), + Counter.longs("o2-ElementCount", SUM).resetToValue(2L), + Counter.longs("test-o2-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o2-start-msecs")).getAggregate(false)), + Counter.longs("test-o2-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o2-process-msecs")).getAggregate(false)), + Counter.longs("test-o2-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o2-finish-msecs")).getAggregate(false)), + Counter.longs("o3-ElementCount", SUM).resetToValue(3L), + Counter.longs("test-o3-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o3-start-msecs")).getAggregate(false)), + Counter.longs("test-o3-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o3-process-msecs")).getAggregate(false)), + Counter.longs("test-o3-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-o3-finish-msecs")).getAggregate(false))), + counterSet); + + executor.close(); + } + + @Test + public void testGetReadOperation() throws Exception { + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + // Test MapTaskExecutor without a single operation. + MapTaskExecutor executor = + new MapTaskExecutor(new ArrayList(), + counterSet, stateSampler); + + try { + ReadOperation readOperation = executor.getReadOperation(); + Assert.fail("Expected IllegalStateException."); + } catch (IllegalStateException e) { + // Exception expected + } + + List operations = Arrays.asList(new Operation[]{ + new TestOperation("o1", + counterPrefix, counterSet.getAddCounterMutator(), + stateSampler, 1), + new TestOperation("o2", + counterPrefix, counterSet.getAddCounterMutator(), + stateSampler, 2)}); + // Test MapTaskExecutor without ReadOperation. + executor = new MapTaskExecutor(operations, counterSet, stateSampler); + + try { + ReadOperation readOperation = executor.getReadOperation(); + Assert.fail("Expected IllegalStateException."); + } catch (IllegalStateException e) { + // Exception expected + } + + executor.close(); + + TestReceiver receiver = new TestReceiver(counterSet, counterPrefix); + operations = Arrays.asList(new Operation[]{ + new TestReadOperation( + receiver, counterPrefix, counterSet.getAddCounterMutator(), + stateSampler)}); + executor = new MapTaskExecutor(operations, counterSet, stateSampler); + Assert.assertEquals(operations.get(0), executor.getReadOperation()); + executor.close(); + } + + @Test + public void testGetProgressAndRequestSplit() throws Exception { + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + TestReceiver receiver = new TestReceiver(counterSet, counterPrefix); + TestReadOperation operation = + new TestReadOperation(receiver, + counterPrefix, counterSet.getAddCounterMutator(), + stateSampler); + MapTaskExecutor executor = new MapTaskExecutor( + Arrays.asList(new Operation[]{operation}), counterSet, stateSampler); + + operation.setProgress(new ApproximateProgress().setPosition(makePosition(1L))); + Assert.assertEquals( + makePosition(1L), + sourceProgressToCloudProgress(executor.getWorkerProgress()).getPosition()); + Assert.assertEquals( + makePosition(1L), + sourcePositionToCloudPosition( + executor.proposeStopPosition( + cloudProgressToSourceProgress( + new ApproximateProgress().setPosition(makePosition(1L)))))); + + executor.close(); + } + + private com.google.api.services.dataflow.model.Position makePosition(long index) { + com.google.api.services.dataflow.model.Position position = + new com.google.api.services.dataflow.model.Position(); + position.setRecordIndex(index); + return position; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/OutputReceiverTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/OutputReceiverTest.java new file mode 100644 index 000000000000..08955ac564d7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/OutputReceiverTest.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.worker.MapTaskExecutorFactory.ElementByteSizeObservableCoder; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.CounterTestUtils; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestReceiver; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for OutputReceiver. + */ +@RunWith(JUnit4.class) +public class OutputReceiverTest { + // We test OutputReceiver where every element is sampled. + static class TestOutputReceiver extends OutputReceiver { + public TestOutputReceiver() { + this(new CounterSet()); + } + + public TestOutputReceiver(CounterSet counters) { + super("output_name", + new ElementByteSizeObservableCoder(StringUtf8Coder.of()), + "test-", + counters.getAddCounterMutator()); + } + + @Override + protected boolean sampleElement() { + return true; + } + } + + @Test + public void testEmptyOutputReceiver() throws Exception { + TestOutputReceiver fanOut = new TestOutputReceiver(); + fanOut.process("hi"); + fanOut.process("bob"); + + Assert.assertEquals("output_name", fanOut.getName()); + Assert.assertEquals( + 2, + (long) CounterTestUtils.getTotalAggregate(fanOut.getElementCount())); + Assert.assertEquals( + 5, + (long) CounterTestUtils.getTotalAggregate(fanOut.getMeanByteCount())); + Assert.assertEquals( + 2, + (long) CounterTestUtils.getTotalCount(fanOut.getMeanByteCount())); + } + + @Test + public void testMultipleOutputReceiver() throws Exception { + TestOutputReceiver fanOut = new TestOutputReceiver(); + + CounterSet counters = new CounterSet(); + String counterPrefix = "test-"; + + TestReceiver receiver1 = new TestReceiver(counters, counterPrefix); + fanOut.addOutput(receiver1); + + TestReceiver receiver2 = new TestReceiver(counters, counterPrefix); + fanOut.addOutput(receiver2); + + fanOut.process("hi"); + fanOut.process("bob"); + + Assert.assertEquals("output_name", fanOut.getName()); + Assert.assertEquals( + 2, + (long) CounterTestUtils.getTotalAggregate(fanOut.getElementCount())); + Assert.assertEquals( + 5, + (long) CounterTestUtils.getTotalAggregate(fanOut.getMeanByteCount())); + Assert.assertEquals( + 2, + (long) CounterTestUtils.getTotalCount(fanOut.getMeanByteCount())); + Assert.assertThat(receiver1.outputElems, + CoreMatchers.hasItems("hi", "bob")); + Assert.assertThat(receiver2.outputElems, + CoreMatchers.hasItems("hi", "bob")); + } + + @Test(expected = ClassCastException.class) + public void testIncorrectType() throws Exception { + TestOutputReceiver fanOut = new TestOutputReceiver(); + fanOut.process(5); + } + + @Test(expected = CoderException.class) + public void testNullArgument() throws Exception { + TestOutputReceiver fanOut = new TestOutputReceiver(); + fanOut.process(null); + } + + @Test + public void testAddingCountersIntoCounterSet() throws Exception { + CounterSet counters = new CounterSet(); + TestOutputReceiver receiver = new TestOutputReceiver(counters); + + Assert.assertEquals( + new CounterSet( + Counter.longs("output_name-ElementCount", SUM) + .resetToValue(0L), + Counter.longs("output_name-MeanByteCount", MEAN) + .resetToValue(0, 0L)), + counters); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoOperationTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoOperationTest.java new file mode 100644 index 000000000000..b08266cbb4d8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ParDoOperationTest.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for ParDoOperation. + */ +@RunWith(JUnit4.class) +public class ParDoOperationTest { + static class TestParDoFn extends ParDoFn { + final OutputReceiver outputReceiver; + + public TestParDoFn(OutputReceiver outputReceiver) { + this.outputReceiver = outputReceiver; + } + + @Override + public void startBundle(final Receiver... receivers) throws Exception { + if (receivers.length != 1) { + throw new AssertionError( + "unexpected number of receivers for DoFn"); + } + + outputReceiver.process("x-start"); + } + + @Override + public void processElement(Object elem) throws Exception { + outputReceiver.process("y-" + elem); + } + + @Override + public void finishBundle() throws Exception { + outputReceiver.process("z-finish"); + } + } + + @Test + public void testRunParDoOperation() throws Exception { + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + ExecutorTestUtils.TestReceiver receiver = + new ExecutorTestUtils.TestReceiver(counterSet); + + ParDoOperation parDoOperation = + new ParDoOperation( + "ParDoOperation", + new TestParDoFn(receiver), + new OutputReceiver[]{ receiver }, + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + + parDoOperation.start(); + + parDoOperation.process("hi"); + parDoOperation.process("there"); + parDoOperation.process(""); + parDoOperation.process("bob"); + + parDoOperation.finish(); + + Assert.assertThat( + receiver.outputElems, + CoreMatchers.hasItems( + "x-start", "y-hi", "y-there", "y-", "y-bob", "z-finish")); + + Assert.assertEquals( + new CounterSet( + Counter.longs("test-ParDoOperation-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ParDoOperation-start-msecs")).getAggregate(false)), + Counter.longs("test-ParDoOperation-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ParDoOperation-process-msecs")).getAggregate(false)), + Counter.longs("test-ParDoOperation-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ParDoOperation-finish-msecs")).getAggregate(false)), + Counter.longs("test_receiver_out-ElementCount", SUM) + .resetToValue(6L), + Counter.longs("test_receiver_out-MeanByteCount", MEAN) + .resetToValue(6, 33L)), + counterSet); + } + + // TODO: Test side inputs. + // TODO: Test side outputs. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/PartialGroupByKeyOperationTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/PartialGroupByKeyOperationTest.java new file mode 100644 index 000000000000..620ac0c89894 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/PartialGroupByKeyOperationTest.java @@ -0,0 +1,397 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.isIn; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.worker.MapTaskExecutorFactory.CoderGroupingKeyCreator; +import com.google.cloud.dataflow.sdk.runners.worker.MapTaskExecutorFactory.CoderSizeEstimator; +import com.google.cloud.dataflow.sdk.runners.worker.MapTaskExecutorFactory.ElementByteSizeObservableCoder; +import com.google.cloud.dataflow.sdk.runners.worker.MapTaskExecutorFactory.PairInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestReceiver; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation.BufferingGroupingTable; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation.Combiner; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation.CombiningGroupingTable; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation.GroupingKeyCreator; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation.SamplingSizeEstimator; +import com.google.cloud.dataflow.sdk.util.common.worker.PartialGroupByKeyOperation.SizeEstimator; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.hamcrest.Description; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import org.hamcrest.collection.IsIterableContainingInAnyOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * Tests for PartialGroupByKeyOperation. + */ +@RunWith(JUnit4.class) +public class PartialGroupByKeyOperationTest { + @Test + public void testRunPartialGroupByKeyOperation() throws Exception { + Coder keyCoder = StringUtf8Coder.of(); + Coder valueCoder = BigEndianIntegerCoder.of(); + + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + TestReceiver receiver = + new TestReceiver( + new ElementByteSizeObservableCoder( + WindowedValue.getValueOnlyCoder( + KvCoder.of(keyCoder, IterableCoder.of(valueCoder)))), + counterSet, counterPrefix); + + PartialGroupByKeyOperation pgbkOperation = + new PartialGroupByKeyOperation(new CoderGroupingKeyCreator(keyCoder), + new CoderSizeEstimator(keyCoder), + new CoderSizeEstimator(valueCoder), + PairInfo.create(), + receiver, + counterPrefix, + counterSet.getAddCounterMutator(), + stateSampler); + + pgbkOperation.start(); + + pgbkOperation.process(WindowedValue.valueInEmptyWindows(KV.of("hi", 4))); + pgbkOperation.process(WindowedValue.valueInEmptyWindows(KV.of("there", 5))); + pgbkOperation.process(WindowedValue.valueInEmptyWindows(KV.of("hi", 6))); + pgbkOperation.process(WindowedValue.valueInEmptyWindows(KV.of("joe", 7))); + pgbkOperation.process(WindowedValue.valueInEmptyWindows(KV.of("there", 8))); + pgbkOperation.process(WindowedValue.valueInEmptyWindows(KV.of("hi", 9))); + + pgbkOperation.finish(); + + assertThat(receiver.outputElems, + IsIterableContainingInAnyOrder.containsInAnyOrder( + WindowedValue.valueInEmptyWindows(KV.of("hi", Arrays.asList(4, 6, 9))), + WindowedValue.valueInEmptyWindows(KV.of("there", Arrays.asList(5, 8))), + WindowedValue.valueInEmptyWindows(KV.of("joe", Arrays.asList(7))))); + + // Exact counter values depend on size of encoded data. If encoding + // changes, then these expected counters should change to match. + assertEquals( + new CounterSet( + Counter.longs("test-PartialGroupByKeyOperation-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-PartialGroupByKeyOperation-start-msecs")).getAggregate(false)), + Counter.longs("test-PartialGroupByKeyOperation-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-PartialGroupByKeyOperation-process-msecs")).getAggregate(false)), + Counter.longs("test-PartialGroupByKeyOperation-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-PartialGroupByKeyOperation-finish-msecs")).getAggregate(false)), + Counter.longs("test_receiver_out-ElementCount", SUM) + .resetToValue(3L), + Counter.longs("test_receiver_out-MeanByteCount", MEAN) + .resetToValue(3, 49L)), + counterSet); + } + + // TODO: Add tests about early flushing when the table fills. + + //////////////////////////////////////////////////////////////////////////// + // Tests for PartialGroupByKey internals. + + /** + * Return the key as its grouping key. + */ + public static class IdentityGroupingKeyCreator implements GroupingKeyCreator { + @Override + public Object createGroupingKey(Object key) { + return key; + } + } + + /** + * "Estimate" the size of longs by looking at their value. + */ + private static class IdentitySizeEstimator implements SizeEstimator { + public int calls = 0; + @Override + public long estimateSize(Long element) { + calls++; + return element; + } + } + + /** + * "Estimate" the size of strings by taking the tenth power of their length. + */ + private static class StringPowerSizeEstimator implements SizeEstimator { + @Override + public long estimateSize(String element) { + return (long) Math.pow(10, element.length()); + } + } + + @Test + public void testBufferingGroupingTable() throws Exception { + BufferingGroupingTable table = + new BufferingGroupingTable<>( + 1000, new IdentityGroupingKeyCreator(), PairInfo.create(), + new StringPowerSizeEstimator(), new StringPowerSizeEstimator()); + TestReceiver receiver = new TestReceiver( + WindowedValue.getValueOnlyCoder( + KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())))); + + table.put("A", "a", receiver); + table.put("B", "b1", receiver); + table.put("B", "b2", receiver); + table.put("C", "c", receiver); + assertThat(unwindowed(receiver.outputElems), empty()); + + table.put("C", "cccc", receiver); + assertThat(unwindowed(receiver.outputElems), + hasItem((Object) KV.of("C", Arrays.asList("c", "cccc")))); + + table.put("DDDD", "d", receiver); + assertThat(unwindowed(receiver.outputElems), + hasItem((Object) KV.of("DDDD", Arrays.asList("d")))); + + table.flush(receiver); + assertThat(unwindowed(receiver.outputElems), + IsIterableContainingInAnyOrder.containsInAnyOrder( + KV.of("A", Arrays.asList("a")), + KV.of("B", Arrays.asList("b1", "b2")), + KV.of("C", Arrays.asList("c", "cccc")), + KV.of("DDDD", Arrays.asList("d")))); + } + + @Test + public void testCombiningGroupingTable() throws Exception { + Combiner summingCombineFn = + new Combiner() { + public Long createAccumulator(Object key) { + return 0L; + } + public Long add(Object key, Long accumulator, Integer value) { + return accumulator + value; + } + public Long merge(Object key, Iterable accumulators) { + long sum = 0; + for (Long part : accumulators) { sum += part; } + return sum; + } + public Long extract(Object key, Long accumulator) { + return accumulator; + } + }; + + CombiningGroupingTable table = + new CombiningGroupingTable( + 1000, new IdentityGroupingKeyCreator(), PairInfo.create(), + summingCombineFn, + new StringPowerSizeEstimator(), new IdentitySizeEstimator()); + + TestReceiver receiver = new TestReceiver( + WindowedValue.getValueOnlyCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianLongCoder.of()))); + + table.put("A", 1, receiver); + table.put("B", 2, receiver); + table.put("B", 3, receiver); + table.put("C", 4, receiver); + assertThat(unwindowed(receiver.outputElems), empty()); + + table.put("C", 5000, receiver); + assertThat(unwindowed(receiver.outputElems), hasItem((Object) KV.of("C", 5004L))); + + table.put("DDDD", 6, receiver); + assertThat(unwindowed(receiver.outputElems), hasItem((Object) KV.of("DDDD", 6L))); + + table.flush(receiver); + assertThat(unwindowed(receiver.outputElems), + IsIterableContainingInAnyOrder.containsInAnyOrder( + KV.of("A", 1L), + KV.of("B", 2L + 3), + KV.of("C", 5000L + 4), + KV.of("DDDD", 6L))); + } + + private List unwindowed(Iterable windowed) { + List unwindowed = new ArrayList<>(); + for (Object withWindow : windowed) { + unwindowed.add(((WindowedValue) withWindow).getValue()); + } + return unwindowed; + } + + + //////////////////////////////////////////////////////////////////////////// + // Tests for the sampling size estimator. + + @Test + public void testSampleFlatSizes() throws Exception { + IdentitySizeEstimator underlying = new IdentitySizeEstimator(); + SizeEstimator estimator = + new SamplingSizeEstimator(underlying, 0.05, 1.0, 10, new Random(1)); + // First 10 elements are always sampled. + for (int k = 0; k < 10; k++) { + assertEquals(100, estimator.estimateSize(100L)); + assertEquals(k + 1, underlying.calls); + } + // Next 10 are sometimes sampled. + for (int k = 10; k < 20; k++) { + assertEquals(100, estimator.estimateSize(100L)); + } + assertThat(underlying.calls, between(11, 19)); + int initialCalls = underlying.calls; + // Next 1000 are sampled at about 5%. + for (int k = 20; k < 1020; k++) { + assertEquals(100, estimator.estimateSize(100L)); + } + assertThat(underlying.calls - initialCalls, between(40, 60)); + } + + @Test + public void testSampleBoringSizes() throws Exception { + IdentitySizeEstimator underlying = new IdentitySizeEstimator(); + SizeEstimator estimator = + new SamplingSizeEstimator(underlying, 0.05, 1.0, 10, new Random(1)); + // First 10 elements are always sampled. + for (int k = 0; k < 10; k += 2) { + assertEquals(100, estimator.estimateSize(100L)); + assertEquals(102, estimator.estimateSize(102L)); + assertEquals(k + 2, underlying.calls); + } + // Next 10 are sometimes sampled. + for (int k = 10; k < 20; k += 2) { + assertThat(estimator.estimateSize(100L), between(100L, 102L)); + assertThat(estimator.estimateSize(102L), between(100L, 102L)); + } + assertThat(underlying.calls, between(11, 19)); + int initialCalls = underlying.calls; + // Next 1000 are sampled at about 5%. + for (int k = 20; k < 1020; k += 2) { + assertThat(estimator.estimateSize(100L), between(100L, 102L)); + assertThat(estimator.estimateSize(102L), between(100L, 102L)); + } + assertThat(underlying.calls - initialCalls, between(40, 60)); + } + + @Test + public void testSampleHighVarianceSizes() throws Exception { + // The largest element is much larger than the average. + List sizes = Arrays.asList(1L, 10L, 100L, 1000L); + IdentitySizeEstimator underlying = new IdentitySizeEstimator(); + SizeEstimator estimator = + new SamplingSizeEstimator(underlying, 0.1, 0.2, 10, new Random(1)); + // First 10 elements are always sampled. + for (int k = 0; k < 10; k++) { + long size = sizes.get(k % sizes.size()); + assertEquals(size, estimator.estimateSize(size)); + assertEquals(k + 1, underlying.calls); + } + // We're still not out of the woods; sample every element. + for (int k = 10; k < 20; k++) { + long size = sizes.get(k % sizes.size()); + assertEquals(size, estimator.estimateSize(size)); + assertEquals(k + 1, underlying.calls); + } + // Sample some more to let things settle down. + for (int k = 20; k < 500; k++) { + estimator.estimateSize(sizes.get(k % sizes.size())); + } + // Next 1000 are sampled at about 20% (maxSampleRate). + int initialCalls = underlying.calls; + for (int k = 500; k < 1500; k++) { + long size = sizes.get(k % sizes.size()); + assertThat(estimator.estimateSize(size), + anyOf(isIn(sizes), between(250L, 350L))); + } + assertThat(underlying.calls - initialCalls, between(180, 220)); + // Sample some more to let things settle down. + for (int k = 1500; k < 3000; k++) { + estimator.estimateSize(sizes.get(k % sizes.size())); + } + // Next 1000 are sampled at about 10% (minSampleRate). + initialCalls = underlying.calls; + for (int k = 3000; k < 4000; k++) { + long size = sizes.get(k % sizes.size()); + assertThat(estimator.estimateSize(size), + anyOf(isIn(sizes), between(250L, 350L))); + } + assertThat(underlying.calls - initialCalls, between(90, 110)); + } + + @Test + public void testSampleChangingSizes() throws Exception { + IdentitySizeEstimator underlying = new IdentitySizeEstimator(); + SizeEstimator estimator = + new SamplingSizeEstimator(underlying, 0.05, 1.0, 10, new Random(1)); + // First 10 elements are always sampled. + for (int k = 0; k < 10; k++) { + assertEquals(100, estimator.estimateSize(100L)); + assertEquals(k + 1, underlying.calls); + } + // Next 10 are sometimes sampled. + for (int k = 10; k < 20; k++) { + assertEquals(100, estimator.estimateSize(100L)); + } + assertThat(underlying.calls, between(11, 19)); + int initialCalls = underlying.calls; + // Next 1000 are sampled at about 5%. + for (int k = 20; k < 1020; k++) { + assertEquals(100, estimator.estimateSize(100L)); + } + assertThat(underlying.calls - initialCalls, between(40, 60)); + // Inject a big element until it is sampled. + while (estimator.estimateSize(1000000L) == 100) { } + // Check that we have started sampling more regularly again. + assertEquals(99, estimator.estimateSize(99L)); + } + + private static > TypeSafeDiagnosingMatcher + between(final T min, final T max) { + return new TypeSafeDiagnosingMatcher() { + @Override + public void describeTo(Description description) { + description.appendText("is between " + min + " and " + max); + } + @Override + protected boolean matchesSafely(T item, Description mismatchDescription) { + return min.compareTo(item) <= 0 && item.compareTo(max) <= 0; + } + }; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ReadOperationTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ReadOperationTest.java new file mode 100644 index 000000000000..b3e29f8e5cf5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ReadOperationTest.java @@ -0,0 +1,303 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.cloudProgressToSourceProgress; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourcePositionToCloudPosition; +import static com.google.cloud.dataflow.sdk.runners.worker.SourceTranslationUtils.sourceProgressToCloudProgress; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.everyItem; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; + +import com.google.api.services.dataflow.model.ApproximateProgress; +import com.google.api.services.dataflow.model.Position; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestReceiver; +import com.google.cloud.dataflow.sdk.util.common.worker.ExecutorTestUtils.TestSource; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Tests for ReadOperation. + */ +@RunWith(JUnit4.class) +public class ReadOperationTest { + private static final long ITERATIONS = 3L; + + /** + * The test Source for testing updating stop position and progress report. + * The number of read iterations is controlled by ITERATIONS. + */ + static class TestTextSource extends Source { + @Override + public SourceIterator iterator() { + return new TestTextSourceIterator(); + } + + class TestTextSourceIterator extends AbstractSourceIterator { + long offset = 0L; + List proposedPositions = + new ArrayList<>(); + + @Override + public boolean hasNext() { + return offset < ITERATIONS; + } + + @Override + public String next() { + if (hasNext()) { + offset++; + return "hi"; + } else { + throw new AssertionError("No next Element."); + } + } + + @Override + public Progress getProgress() { + com.google.api.services.dataflow.model.Position currentPosition = + new com.google.api.services.dataflow.model.Position(); + currentPosition.setByteOffset(offset); + + ApproximateProgress progress = new ApproximateProgress(); + progress.setPosition(currentPosition); + + return cloudProgressToSourceProgress(progress); + } + + @Override + public Position updateStopPosition(Progress proposedStopPosition) { + proposedPositions.add(sourceProgressToCloudProgress(proposedStopPosition).getPosition()); + // Actually no update happens, returns null. + return null; + } + } + } + + /** + * The OutputReceiver for testing updating stop position and progress report. + * The offset of the Source (iterator) will be advanced each time this + * Receiver processes a record. + */ + static class TestTextReceiver extends OutputReceiver { + ReadOperation readOperation = null; + com.google.api.services.dataflow.model.Position proposedStopPosition = null; + List progresses = new ArrayList<>(); + + public TestTextReceiver(CounterSet counterSet, String counterPrefix) { + super("test_receiver_out", counterPrefix, counterSet.getAddCounterMutator()); + } + + public void setReadOperation(ReadOperation readOp) { + this.readOperation = readOp; + } + + public void setProposedStopPosition(com.google.api.services.dataflow.model.Position position) { + this.proposedStopPosition = position; + } + + @Override + public void process(Object outputElem) throws Exception { + // Calls getProgress() and proposeStopPosition() in each iteration. + progresses.add(sourceProgressToCloudProgress(readOperation.getProgress())); + // We expect that call to proposeStopPosition is a no-op that does not + // update the stop position for every iteration. We will verify it is + // delegated to SourceIterator after ReadOperation finishes. + Assert.assertNull( + readOperation.proposeStopPosition( + cloudProgressToSourceProgress(makeApproximateProgress(proposedStopPosition)))); + } + } + + @Test + public void testRunReadOperation() throws Exception { + TestSource source = new TestSource(); + source.addInput("hi", "there", "", "bob"); + + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, counterSet.getAddCounterMutator()); + TestReceiver receiver = new TestReceiver(counterSet, counterPrefix); + + ReadOperation readOperation = new ReadOperation( + source, receiver, counterPrefix, counterSet.getAddCounterMutator(), stateSampler); + + readOperation.start(); + readOperation.finish(); + + Assert.assertThat( + receiver.outputElems, CoreMatchers.hasItems("hi", "there", "", "bob")); + + Assert.assertEquals( + new CounterSet( + Counter.longs("ReadOperation-ByteCount", SUM).resetToValue(2L + 5 + 0 + 3), + Counter.longs("test_receiver_out-ElementCount", SUM).resetToValue(4L), + Counter.longs("test_receiver_out-MeanByteCount", MEAN).resetToValue(4, 10L), + Counter.longs("test-ReadOperation-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ReadOperation-start-msecs")).getAggregate(false)), + Counter.longs("test-ReadOperation-read-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ReadOperation-read-msecs")).getAggregate(false)), + Counter.longs("test-ReadOperation-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ReadOperation-process-msecs")).getAggregate(false)), + Counter.longs("test-ReadOperation-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-ReadOperation-finish-msecs")).getAggregate(false))), + counterSet); + } + + @Test + public void testGetProgressAndProposeStopPosition() throws Exception { + TestTextSource testSource = new TestTextSource(); + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, counterSet.getAddCounterMutator()); + TestTextReceiver receiver = new TestTextReceiver(counterSet, counterPrefix); + ReadOperation readOperation = new ReadOperation( + testSource, receiver, counterPrefix, counterSet.getAddCounterMutator(), stateSampler); + readOperation.setProgressUpdatePeriodMs(0); + receiver.setReadOperation(readOperation); + + Position proposedStopPosition = makePosition(3L); + receiver.setProposedStopPosition(proposedStopPosition); + + Assert.assertNull(readOperation.getProgress()); + Assert.assertNull(readOperation.proposeStopPosition( + cloudProgressToSourceProgress( + makeApproximateProgress(proposedStopPosition)))); + + readOperation.start(); + readOperation.finish(); + + TestTextSource.TestTextSourceIterator testIterator = + (TestTextSource.TestTextSourceIterator) readOperation.sourceIterator; + + Assert.assertEquals(sourceProgressToCloudProgress(testIterator.getProgress()), + sourceProgressToCloudProgress(readOperation.getProgress())); + Assert.assertEquals(sourcePositionToCloudPosition(testIterator.updateStopPosition( + cloudProgressToSourceProgress( + makeApproximateProgress(proposedStopPosition)))), + sourcePositionToCloudPosition(readOperation.proposeStopPosition( + cloudProgressToSourceProgress( + makeApproximateProgress(proposedStopPosition))))); + + // Verifies progress report and stop position updates. + Assert.assertEquals(testIterator.proposedPositions.size(), ITERATIONS + 2); + Assert.assertThat( + testIterator.proposedPositions, everyItem(equalTo(makePosition(3L)))); + Assert.assertThat( + receiver.progresses, contains(makeApproximateProgress(1L), makeApproximateProgress(2L), + makeApproximateProgress(3L))); + } + + @Test + public void testGetProgressDoesNotBlock() throws Exception { + final BlockingQueue queue = new LinkedBlockingQueue<>(); + final Source.SourceIterator iterator = new Source.AbstractSourceIterator() { + private int itemsReturned = 0; + + @Override + public boolean hasNext() throws IOException { + return itemsReturned < 5; + } + + @Override + public Integer next() throws IOException { + ++itemsReturned; + try { + return queue.take(); + } catch (InterruptedException e) { + throw new NoSuchElementException("interrupted"); + } + } + + @Override + public Source.Progress getProgress() { + return cloudProgressToSourceProgress(new ApproximateProgress().setPosition( + new Position().setRecordIndex((long) itemsReturned))); + } + }; + + Source source = new Source() { + @Override + public SourceIterator iterator() throws IOException { + return iterator; + } + }; + + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler(counterPrefix, counterSet.getAddCounterMutator()); + TestTextReceiver receiver = new TestTextReceiver(counterSet, counterPrefix); + final ReadOperation readOperation = new ReadOperation( + source, receiver, counterPrefix, counterSet.getAddCounterMutator(), stateSampler); + // Update progress not continuously, but so that it's never more than 1 record stale. + readOperation.setProgressUpdatePeriodMs(150); + receiver.setReadOperation(readOperation); + + new Thread() { + @Override + public void run() { + try { + readOperation.start(); + readOperation.finish(); + } catch (Exception e) { + e.printStackTrace(); + } + } + }.start(); + + for (int i = 0; i < 5; ++i) { + Thread.sleep(100); // Wait for the operation to start and block. + // Ensure that getProgress() doesn't block. + ApproximateProgress progress = sourceProgressToCloudProgress(readOperation.getProgress()); + long observedIndex = progress.getPosition().getRecordIndex().longValue(); + Assert.assertTrue("Actual: " + observedIndex, i == observedIndex || i == observedIndex + 1); + queue.offer(i); + } + } + + private static Position makePosition(long offset) { + return new Position().setByteOffset(offset); + } + + private static ApproximateProgress makeApproximateProgress(long offset) { + return makeApproximateProgress(makePosition(offset)); + } + + private static ApproximateProgress makeApproximateProgress( + com.google.api.services.dataflow.model.Position position) { + return new ApproximateProgress().setPosition(position); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntryTest.java new file mode 100644 index 000000000000..10e3b4da63f4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/ShuffleEntryTest.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ShuffleEntry}. */ +@RunWith(JUnit4.class) +public class ShuffleEntryTest { + private static final byte[] KEY = {0xA}; + private static final byte[] SKEY = {0xB}; + private static final byte[] VALUE = {0xC}; + + @Test + public void accessors() { + ShuffleEntry entry = new ShuffleEntry(KEY, SKEY, VALUE); + assertThat(entry.getKey(), equalTo(KEY)); + assertThat(entry.getSecondaryKey(), equalTo(SKEY)); + assertThat(entry.getValue(), equalTo(VALUE)); + } + + @Test + public void equalsToItself() { + ShuffleEntry entry = new ShuffleEntry(KEY, SKEY, VALUE); + assertTrue(entry.equals(entry)); + } + + @Test + public void equalsForEqualEntries() { + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry( + KEY.clone(), SKEY.clone(), VALUE.clone()); + + assertTrue(entry0.equals(entry1)); + assertTrue(entry1.equals(entry0)); + assertEquals(entry0.hashCode(), entry1.hashCode()); + } + + @Test + public void equalsForEqualNullEntries() { + ShuffleEntry entry0 = new ShuffleEntry(null, null, null); + ShuffleEntry entry1 = new ShuffleEntry(null, null, null); + + assertTrue(entry0.equals(entry1)); + assertTrue(entry1.equals(entry0)); + assertEquals(entry0.hashCode(), entry1.hashCode()); + } + + @Test + public void notEqualsWhenKeysDiffer() { + final byte[] otherKey = {0x1}; + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry(otherKey, SKEY, VALUE); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } + + @Test + public void notEqualsWhenKeysDifferOneNull() { + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry(null, SKEY, VALUE); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } + + @Test + public void notEqualsWhenSecondaryKeysDiffer() { + final byte[] otherSKey = {0x2}; + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry(KEY, otherSKey, VALUE); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } + + @Test + public void notEqualsWhenSecondaryKeysDifferOneNull() { + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry(KEY, null, VALUE); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } + + @Test + public void notEqualsWhenValuesDiffer() { + final byte[] otherValue = {0x2}; + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry(KEY, SKEY, otherValue); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } + + @Test + public void notEqualsWhenValuesDifferOneNull() { + ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE); + ShuffleEntry entry1 = new ShuffleEntry(KEY, SKEY, null); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } + + @Test + public void emptyNotTheSameAsNull() { + final byte[] empty = {}; + ShuffleEntry entry0 = new ShuffleEntry(null, null, null); + ShuffleEntry entry1 = new ShuffleEntry(empty, empty, empty); + + assertFalse(entry0.equals(entry1)); + assertFalse(entry1.equals(entry0)); + assertThat(entry0.hashCode(), not(equalTo(entry1.hashCode()))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSamplerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSamplerTest.java new file mode 100644 index 000000000000..d350db1798bf --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSamplerTest.java @@ -0,0 +1,139 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.AbstractMap.SimpleEntry; +import java.util.Map; + +/** + * Unit tests for the {@link Counter} API. + */ +@RunWith(JUnit4.class) +public class StateSamplerTest { + + @Test + public void basicTest() throws InterruptedException { + CounterSet counters = new CounterSet(); + long periodMs = 50; + StateSampler stateSampler = new StateSampler("test-", + counters.getAddCounterMutator(), periodMs); + + int state1 = stateSampler.stateForName("1"); + int state2 = stateSampler.stateForName("2"); + + assertEquals(new SimpleEntry<>("", 0L), + stateSampler.getCurrentStateAndDuration()); + + try (StateSampler.ScopedState s1 = + stateSampler.scopedState(state1)) { + Thread.sleep(2 * periodMs); + } + + try (StateSampler.ScopedState s2 = + stateSampler.scopedState(state2)) { + Thread.sleep(3 * periodMs); + } + + long s1 = stateSampler.getStateDuration(state1); + long s2 = stateSampler.getStateDuration(state2); + + System.out.println("basic s1: " + s1); + System.out.println("basic s2: " + s2); + + long toleranceMs = periodMs; + assertTrue(s1 + s2 >= 4 * periodMs - toleranceMs); + assertTrue(s1 + s2 <= 10 * periodMs + toleranceMs); + } + + @Test + public void nestingTest() throws InterruptedException { + CounterSet counters = new CounterSet(); + long periodMs = 50; + StateSampler stateSampler = new StateSampler("test-", + counters.getAddCounterMutator(), periodMs); + + int state1 = stateSampler.stateForName("1"); + int state2 = stateSampler.stateForName("2"); + int state3 = stateSampler.stateForName("3"); + + assertEquals(new SimpleEntry<>("", 0L), + stateSampler.getCurrentStateAndDuration()); + + try (StateSampler.ScopedState s1 = + stateSampler.scopedState(state1)) { + Thread.sleep(2 * periodMs); + + try (StateSampler.ScopedState s2 = + stateSampler.scopedState(state2)) { + Thread.sleep(2 * periodMs); + + try (StateSampler.ScopedState s3 = + stateSampler.scopedState(state3)) { + Thread.sleep(2 * periodMs); + } + Thread.sleep(periodMs); + } + Thread.sleep(periodMs); + } + + long s1 = stateSampler.getStateDuration(state1); + long s2 = stateSampler.getStateDuration(state2); + long s3 = stateSampler.getStateDuration(state3); + + System.out.println("s1: " + s1); + System.out.println("s2: " + s2); + System.out.println("s3: " + s3); + + long toleranceMs = periodMs; + assertTrue(s1 + s2 + s3 >= 4 * periodMs - toleranceMs); + assertTrue(s1 + s2 + s3 <= 16 * periodMs + toleranceMs); + } + + @Test + public void nonScopedTest() throws InterruptedException { + CounterSet counters = new CounterSet(); + long periodMs = 50; + StateSampler stateSampler = new StateSampler("test-", + counters.getAddCounterMutator(), periodMs); + + int state1 = stateSampler.stateForName("1"); + int previousState = stateSampler.setState(state1); + Thread.sleep(2 * periodMs); + Map.Entry currentStateAndDuration = + stateSampler.getCurrentStateAndDuration(); + stateSampler.setState(previousState); + assertEquals("test-1-msecs", currentStateAndDuration.getKey()); + long tolerance = periodMs; + long s = currentStateAndDuration.getValue(); + System.out.println("s: " + s); + assertTrue(s >= periodMs - tolerance); + assertTrue(s <= 4 * periodMs + tolerance); + + assertTrue(stateSampler.getCurrentStateAndDuration() + .getKey().isEmpty()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkExecutorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkExecutorTest.java new file mode 100644 index 000000000000..ecce00d68b76 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/WorkExecutorTest.java @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; + +import com.google.cloud.dataflow.sdk.util.common.Metric; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Collection; + +/** + * Unit tests for {@link WorkExecutor}. + */ +@RunWith(JUnit4.class) +public class WorkExecutorTest { + private WorkExecutor mapWorker; + private WorkExecutor seqMapWorker; + + @Before + public void setUp() { + mapWorker = new MapTaskExecutor(null, null, null); + } + + @Test + public void testMapTaskGetOutputMetrics() { + Collection> metrics = mapWorker.getOutputMetrics(); + verifyOutputMetrics(metrics); + } + + private void verifyOutputMetrics(Collection> metrics) { + Collection metricNames = new ArrayList<>(); + for (Metric metric : metrics) { + metricNames.add(metric.getName()); + } + Assert.assertThat(metricNames, containsInAnyOrder("CPU")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/WriteOperationTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/WriteOperationTest.java new file mode 100644 index 000000000000..6b51bc603531 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/worker/WriteOperationTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for WriteOperation. + */ +@RunWith(JUnit4.class) +public class WriteOperationTest { + @Test + public void testRunWriteOperation() throws Exception { + ExecutorTestUtils.TestSink sink = new ExecutorTestUtils.TestSink(); + CounterSet counterSet = new CounterSet(); + String counterPrefix = "test-"; + StateSampler stateSampler = new StateSampler( + counterPrefix, counterSet.getAddCounterMutator()); + + WriteOperation writeOperation = new WriteOperation( + sink, counterPrefix, counterSet.getAddCounterMutator(), stateSampler); + + writeOperation.start(); + + writeOperation.process("hi"); + writeOperation.process("there"); + writeOperation.process(""); + writeOperation.process("bob"); + + writeOperation.finish(); + + Assert.assertThat(sink.outputElems, + CoreMatchers.hasItems("hi", "there", "", "bob")); + + Assert.assertEquals( + new CounterSet( + Counter.longs("WriteOperation-ByteCount", SUM) + .resetToValue(2L + 5 + 0 + 3), + Counter.longs("test-WriteOperation-start-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-WriteOperation-start-msecs")).getAggregate(false)), + Counter.longs("test-WriteOperation-process-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-WriteOperation-process-msecs")).getAggregate(false)), + Counter.longs("test-WriteOperation-finish-msecs", SUM) + .resetToValue(((Counter) counterSet.getExistingCounter( + "test-WriteOperation-finish-msecs")).getAggregate(false))), + counterSet); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPathTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPathTest.java new file mode 100644 index 000000000000..9904bd5a2428 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPathTest.java @@ -0,0 +1,334 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsfs; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * Tests of GcsPath. + */ +@RunWith(JUnit4.class) +public class GcsPathTest { + + /** + * Test case, which tests parsing and building of GcsPaths. + */ + static final class TestCase { + + final String uri; + final String expectedBucket; + final String expectedObject; + final String[] namedComponents; + + TestCase(String uri, String... namedComponents) { + this.uri = uri; + this.expectedBucket = namedComponents[0]; + this.namedComponents = namedComponents; + this.expectedObject = uri.substring(expectedBucket.length() + 6); + } + } + + // Each test case is an expected URL, then the components used to build it. + // Empty components result in a double slash. + static final List PATH_TEST_CASES = Arrays.asList( + new TestCase("gs://bucket/then/object", "bucket", "then", "object"), + new TestCase("gs://bucket//then/object", "bucket", "", "then", "object"), + new TestCase("gs://bucket/then//object", "bucket", "then", "", "object"), + new TestCase("gs://bucket/then///object", "bucket", "then", "", "", "object"), + new TestCase("gs://bucket/then/object/", "bucket", "then", "object/"), + new TestCase("gs://bucket/then/object/", "bucket", "then/", "object/"), + new TestCase("gs://bucket/then/object//", "bucket", "then", "object", ""), + new TestCase("gs://bucket/then/object//", "bucket", "then", "object/", ""), + new TestCase("gs://bucket/", "bucket") + ); + + @Test + public void testGcsPathParsing() throws IOException { + for (TestCase testCase : PATH_TEST_CASES) { + String uriString = testCase.uri; + + GcsPath path = GcsPath.fromUri(URI.create(uriString)); + // Deconstruction - check bucket, object, and components. + assertEquals(testCase.expectedBucket, path.getBucket()); + assertEquals(testCase.expectedObject, path.getObject()); + assertEquals(testCase.uri, + testCase.namedComponents.length, path.getNameCount()); + + // Construction - check that the path can be built from components. + GcsPath built = GcsPath.fromComponents(null, null); + for (String component : testCase.namedComponents) { + built = built.resolve(component); + } + assertEquals(testCase.uri, built.toString()); + } + } + + @Test + public void testParentRelationship() throws IOException { + GcsPath path = GcsPath.fromComponents("bucket", "then/object"); + assertEquals("bucket", path.getBucket()); + assertEquals("then/object", path.getObject()); + assertEquals(3, path.getNameCount()); + assertTrue(path.endsWith("object")); + assertTrue(path.startsWith("bucket/then")); + + GcsPath parent = path.getParent(); // gs://bucket/then/ + assertEquals("bucket", parent.getBucket()); + assertEquals("then/", parent.getObject()); + assertEquals(2, parent.getNameCount()); + assertThat(path, Matchers.not(Matchers.equalTo(parent))); + assertTrue(path.startsWith(parent)); + assertFalse(parent.startsWith(path)); + assertTrue(parent.endsWith("then/")); + assertTrue(parent.startsWith("bucket/then")); + assertTrue(parent.isAbsolute()); + + GcsPath root = path.getRoot(); + assertEquals(0, root.getNameCount()); + assertEquals("gs://", root.toString()); + assertEquals("", root.getBucket()); + assertEquals("", root.getObject()); + assertTrue(root.isAbsolute()); + assertThat(root, Matchers.equalTo(parent.getRoot())); + + GcsPath grandParent = parent.getParent(); // gs://bucket/ + assertEquals(1, grandParent.getNameCount()); + assertEquals("gs://bucket/", grandParent.toString()); + assertTrue(grandParent.isAbsolute()); + assertThat(root, Matchers.equalTo(grandParent.getParent())); + assertThat(root.getParent(), Matchers.nullValue()); + + assertTrue(path.startsWith(path.getRoot())); + assertTrue(parent.startsWith(path.getRoot())); + } + + @Test + public void testRelativeParent() throws IOException { + GcsPath path = GcsPath.fromComponents(null, "a/b"); + GcsPath parent = path.getParent(); + assertEquals("a/", parent.toString()); + + GcsPath grandParent = parent.getParent(); + assertNull(grandParent); + } + + @Test + public void testUriSupport() throws IOException { + URI uri = URI.create("gs://bucket/some/path"); + + GcsPath path = GcsPath.fromUri(uri); + assertEquals("bucket", path.getBucket()); + assertEquals("some/path", path.getObject()); + + URI reconstructed = path.toUri(); + assertEquals(uri, reconstructed); + + path = GcsPath.fromUri("gs://bucket"); + assertEquals("gs://bucket/", path.toString()); + } + + @Test + public void testBucketParsing() throws IOException { + GcsPath path = GcsPath.fromUri("gs://bucket"); + GcsPath path2 = GcsPath.fromUri("gs://bucket/"); + + assertEquals(path, path2); + assertEquals(path.toString(), path2.toString()); + assertEquals(path.toUri(), path2.toUri()); + } + + @Test + public void testGcsPathToString() throws Exception { + String filename = "gs://some_bucket/some/file.txt"; + GcsPath path = GcsPath.fromUri(filename); + assertEquals(filename, path.toString()); + } + + @Test + public void testEquals() { + GcsPath a = GcsPath.fromComponents(null, "a/b/c"); + GcsPath a2 = GcsPath.fromComponents(null, "a/b/c"); + assertFalse(a.isAbsolute()); + assertFalse(a2.isAbsolute()); + + GcsPath b = GcsPath.fromComponents("bucket", "a/b/c"); + GcsPath b2 = GcsPath.fromComponents("bucket", "a/b/c"); + assertTrue(b.isAbsolute()); + assertTrue(b2.isAbsolute()); + + assertEquals(a, a); + assertThat(a, Matchers.not(Matchers.equalTo(b))); + assertThat(b, Matchers.not(Matchers.equalTo(a))); + + assertEquals(a, a2); + assertEquals(a2, a); + assertEquals(b, b2); + assertEquals(b2, b); + + assertThat(a, Matchers.not(Matchers.equalTo(Paths.get("/tmp/foo")))); + assertTrue(a != null); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidGcsPath() { + @SuppressWarnings("unused") + GcsPath filename = + GcsPath.fromUri("file://invalid/gcs/path"); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidBucket() { + GcsPath.fromComponents("invalid/", ""); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidObject_newline() { + GcsPath.fromComponents(null, "a\nb"); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidObject_cr() { + GcsPath.fromComponents(null, "a\rb"); + } + + @Test + public void testResolveUri() { + GcsPath path = GcsPath.fromComponents("bucket", "a/b/c"); + GcsPath d = path.resolve("gs://bucket2/d"); + assertEquals("gs://bucket2/d", d.toString()); + } + + @Test + public void testResolveOther() { + GcsPath a = GcsPath.fromComponents("bucket", "a"); + GcsPath b = a.resolve(Paths.get("b")); + assertEquals("a/b", b.getObject()); + } + + @Test + public void testCompareTo() { + GcsPath a = GcsPath.fromComponents("bucket", "a"); + GcsPath b = GcsPath.fromComponents("bucket", "b"); + GcsPath b2 = GcsPath.fromComponents("bucket2", "b"); + GcsPath brel = GcsPath.fromComponents(null, "b"); + GcsPath a2 = GcsPath.fromComponents("bucket", "a"); + GcsPath arel = GcsPath.fromComponents(null, "a"); + + assertThat(a.compareTo(b), Matchers.lessThan(0)); + assertThat(b.compareTo(a), Matchers.greaterThan(0)); + assertThat(a.compareTo(a2), Matchers.equalTo(0)); + + assertThat(a.hashCode(), Matchers.equalTo(a2.hashCode())); + assertThat(a.hashCode(), Matchers.not(Matchers.equalTo(b.hashCode()))); + assertThat(b.hashCode(), Matchers.not(Matchers.equalTo(brel.hashCode()))); + + assertThat(brel.compareTo(b), Matchers.lessThan(0)); + assertThat(b.compareTo(brel), Matchers.greaterThan(0)); + assertThat(arel.compareTo(brel), Matchers.lessThan(0)); + assertThat(brel.compareTo(arel), Matchers.greaterThan(0)); + + assertThat(b.compareTo(b2), Matchers.lessThan(0)); + assertThat(b2.compareTo(b), Matchers.greaterThan(0)); + } + + @Test + public void testCompareTo_ordering() { + GcsPath ab = GcsPath.fromComponents("bucket", "a/b"); + GcsPath abc = GcsPath.fromComponents("bucket", "a/b/c"); + GcsPath a1b = GcsPath.fromComponents("bucket", "a-1/b"); + + assertThat(ab.compareTo(a1b), Matchers.lessThan(0)); + assertThat(a1b.compareTo(ab), Matchers.greaterThan(0)); + + assertThat(ab.compareTo(abc), Matchers.lessThan(0)); + assertThat(abc.compareTo(ab), Matchers.greaterThan(0)); + } + + @Test + public void testCompareTo_buckets() { + GcsPath a = GcsPath.fromComponents(null, "a/b/c"); + GcsPath b = GcsPath.fromComponents("bucket", "a/b/c"); + + assertThat(a.compareTo(b), Matchers.lessThan(0)); + assertThat(b.compareTo(a), Matchers.greaterThan(0)); + } + + @Test + public void testIterator() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c"); + Iterator it = a.iterator(); + + assertTrue(it.hasNext()); + assertEquals("gs://bucket/", it.next().toString()); + assertTrue(it.hasNext()); + assertEquals("a", it.next().toString()); + assertTrue(it.hasNext()); + assertEquals("b", it.next().toString()); + assertTrue(it.hasNext()); + assertEquals("c", it.next().toString()); + assertFalse(it.hasNext()); + } + + @Test + public void testSubpath() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c/d"); + assertThat(a.subpath(0, 1).toString(), Matchers.equalTo("gs://bucket/")); + assertThat(a.subpath(0, 2).toString(), Matchers.equalTo("gs://bucket/a")); + assertThat(a.subpath(0, 3).toString(), Matchers.equalTo("gs://bucket/a/b")); + assertThat(a.subpath(0, 4).toString(), Matchers.equalTo("gs://bucket/a/b/c")); + assertThat(a.subpath(1, 2).toString(), Matchers.equalTo("a")); + assertThat(a.subpath(2, 3).toString(), Matchers.equalTo("b")); + assertThat(a.subpath(2, 4).toString(), Matchers.equalTo("b/c")); + assertThat(a.subpath(2, 5).toString(), Matchers.equalTo("b/c/d")); + } + + @Test + public void testGetName() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c/d"); + assertEquals(5, a.getNameCount()); + assertThat(a.getName(0).toString(), Matchers.equalTo("gs://bucket/")); + assertThat(a.getName(1).toString(), Matchers.equalTo("a")); + assertThat(a.getName(2).toString(), Matchers.equalTo("b")); + assertThat(a.getName(3).toString(), Matchers.equalTo("c")); + assertThat(a.getName(4).toString(), Matchers.equalTo("d")); + } + + @Test(expected = IllegalArgumentException.class) + public void testSubPathError() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c/d"); + a.subpath(1, 1); // throws IllegalArgumentException + Assert.fail(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsio/LoggingMediaHttpUploaderProgressListenerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsio/LoggingMediaHttpUploaderProgressListenerTest.java new file mode 100644 index 000000000000..96e5bf6b4988 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsio/LoggingMediaHttpUploaderProgressListenerTest.java @@ -0,0 +1,83 @@ +/** + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsio; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; + +import com.google.api.client.googleapis.media.MediaHttpUploader.UploadState; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.slf4j.Logger; + +/** Unit tests for {@link LoggingMediaHttpUploaderProgressListener}. */ +@RunWith(JUnit4.class) +public class LoggingMediaHttpUploaderProgressListenerTest { + @Mock + private Logger mockLogger; + private LoggingMediaHttpUploaderProgressListener listener; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + listener = new LoggingMediaHttpUploaderProgressListener("NAME", 60000L); + } + + @Test + public void testLoggingInitiation() { + listener.progressChanged(mockLogger, UploadState.INITIATION_STARTED, 0L, 0L); + verify(mockLogger).info("Uploading: {}", "NAME"); + verifyNoMoreInteractions(mockLogger); + } + + @Test + public void testLoggingProgressAfterSixtySeconds() { + listener.progressChanged(mockLogger, UploadState.MEDIA_IN_PROGRESS, 10485760L, 60001L); + listener.progressChanged(mockLogger, UploadState.MEDIA_IN_PROGRESS, 104857600L, 120002L); + verify(mockLogger).info( + "Uploading: NAME Average Rate: 0.167 MiB/s, Current Rate: 0.167 MiB/s, Total: 10.000 MiB"); + verify(mockLogger).info( + "Uploading: NAME Average Rate: 0.833 MiB/s, Current Rate: 1.500 MiB/s, Total: 100.000 MiB"); + verifyNoMoreInteractions(mockLogger); + } + + @Test + public void testSkippingLoggingAnInProgressUpdate() { + listener.progressChanged(mockLogger, UploadState.MEDIA_IN_PROGRESS, 104857600L, 60000L); + verifyZeroInteractions(mockLogger); + } + + @Test + public void testLoggingCompletion() { + listener.progressChanged(mockLogger, UploadState.MEDIA_COMPLETE, 104857600L, 60000L); + verify(mockLogger).info("Finished Uploading: {}", "NAME"); + verifyNoMoreInteractions(mockLogger); + } + + @Test + public void testOtherUpdatesIgnored() { + listener.progressChanged(mockLogger, UploadState.NOT_STARTED, 0L, 60001L); + listener.progressChanged(mockLogger, UploadState.INITIATION_COMPLETE, 0L, 60001L); + verifyZeroInteractions(mockLogger); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/KVTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/KVTest.java new file mode 100644 index 000000000000..dae544fb033a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/KVTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Comparator; + +/** + * Tests for KV. + */ +@RunWith(JUnit4.class) +public class KVTest { + static final Integer testValues[] = + {null, Integer.MIN_VALUE, -1, 0, 1, Integer.MAX_VALUE}; + + // Wrapper around Integer.compareTo() to support null values. + private int compareInt(Integer a, Integer b) { + if (a == null) { + return b == null ? 0 : -1; + } else { + return b == null ? 1 : a.compareTo(b); + } + } + + @Test + public void testOrderByKey() { + Comparator> orderByKey = new KV.OrderByKey<>(); + for (Integer key1 : testValues) { + for (Integer val1 : testValues) { + for (Integer key2 : testValues) { + for (Integer val2 : testValues) { + assertEquals(compareInt(key1, key2), + orderByKey.compare(KV.of(key1, val1), KV.of(key2, val2))); + } + } + } + } + } + + @Test + public void testOrderByValue() { + Comparator> orderByValue = new KV.OrderByValue<>(); + for (Integer key1 : testValues) { + for (Integer val1 : testValues) { + for (Integer key2 : testValues) { + for (Integer val2 : testValues) { + assertEquals(compareInt(val1, val2), + orderByValue.compare(KV.of(key1, val1), KV.of(key2, val2))); + } + } + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionListTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionListTest.java new file mode 100644 index 000000000000..a6c180fc9abe --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionListTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collections; + +/** + * Tests for PCollectionLists. + */ +@RunWith(JUnit4.class) +public class PCollectionListTest { + @Test + public void testEmptyListFailure() { + try { + PCollectionList.of(Collections.>emptyList()); + fail("should have failed"); + } catch (IllegalArgumentException exn) { + assertThat( + exn.toString(), + containsString( + "must either have a non-empty list of PCollections, " + + "or must first call empty(Pipeline)")); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PDoneTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PDoneTest.java new file mode 100644 index 000000000000..e886f350c12d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PDoneTest.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; + +/** + * Tests for PDone. + */ +@RunWith(JUnit4.class) +public class PDoneTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + /** + * A PTransform that just returns a fresh PDone. + */ + static class EmptyTransform extends PTransform { + @Override + public PDone apply(PBegin begin) { + return new PDone(); + } + } + + /** + * A PTransform that's composed of something that returns a PDone. + */ + static class SimpleTransform extends PTransform { + private final String filename; + + public SimpleTransform(String filename) { + this.filename = filename; + } + + @Override + public PDone apply(PBegin begin) { + return + begin + .apply(Create.of(LINES)) + .apply(TextIO.Write.to(filename)); + } + } + + // TODO: This test doesn't work, because we can't handle composite + // transforms that contain no nested transforms. + // @Test + // @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void DISABLED_testEmptyTransform() { + Pipeline p = TestPipeline.create(); + + p.begin().apply(new EmptyTransform()); + + p.run(); + } + + // Cannot run on the service, unless we allocate a GCS temp file + // instead of a local temp file. Or switch to applying a different + // transform that returns PDone. + @Test + public void testSimpleTransform() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + Pipeline p = TestPipeline.create(); + + p.begin().apply(new SimpleTransform(filename)); + + p.run(); + } +}

To allow such anonymous {@code *Fn}s to be written + * conveniently, {@code PTransform} is marked as {@code Serializable}, + * and includes dummy {@code writeObject()} and {@code readObject()} + * operations that do not save or restore any state. + * + * @see Applying Transformations + * + * @param the type of the input to this PTransform + * @param the type of the output of this PTransform + */ +public abstract class PTransform + implements Serializable /* See the note above */ { + + /** + * Applies this {@code PTransform} on the given {@code Input}, and returns its + * {@code Output}. + * + *