From 7c1c2509242f88a877ffa60f2ae8894710a6e7a1 Mon Sep 17 00:00:00 2001
From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com>
Date: Tue, 5 Sep 2023 12:44:10 -0400
Subject: [PATCH] Add use_public_ips to Go Dataflow Runner (#28308)

---
Review notes (placed below the three-dash line, so they are not part of the
commit message):

* use_public_ips defaults to true and no_use_public_ips defaults to false, so
  the block added to getJobOptions only runs when the two values are equal,
  i.e. when they contradict each other.
* isFlagPassed walks flag.Visit and compares flag names; it is used to tell
  whether a flag was set explicitly. Inside the added block: if only
  use_public_ips was set explicitly, noUsePublicIPs is overwritten with
  !usePublicIPs; if both flags were set explicitly, getJobOptions returns an
  error; in every other case both values are left untouched.
* TestGetJobOptions_BadTruePublicIPs and TestGetJobOptions_BadFalsePublicIPs
  assign the flag variables directly rather than through the flag package.
  NOTE(review): isFlagPassed presumably reports false for both flags in that
  case, so the error these tests observe may come from another check in
  getJobOptions (these tests set no project) rather than from the conflict
  branch. Confirm before relying on them as coverage for that branch.

 sdks/go/pkg/beam/runners/dataflow/dataflow.go | 22 ++++++
 .../beam/runners/dataflow/dataflow_test.go    | 77 +++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index 36e418fb5231..7b43ba78f054 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -69,6 +69,7 @@ var (
 	network = flag.String("network", "", "GCP network (optional)")
 	subnetwork = flag.String("subnetwork", "", "GCP subnetwork (optional)")
 	noUsePublicIPs = flag.Bool("no_use_public_ips", false, "Workers must not use public IP addresses (optional)")
+	usePublicIPs = flag.Bool("use_public_ips", true, "Workers must use public IP addresses (optional)")
 	tempLocation = flag.String("temp_location", "", "Temp location (optional)")
 	workerMachineType = flag.String("worker_machine_type", "", "GCE machine type (optional)")
 	machineType = flag.String("machine_type", "", "alias of worker_machine_type (optional)")
@@ -245,6 +246,16 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error)
 	return dataflowlib.Execute(ctx, model, opts, workerURL, modelURL, *endpoint, *jobopts.Async)
 }
 
+func isFlagPassed(name string) bool {
+	found := false
+	flag.Visit(func(f *flag.Flag) {
+		if f.Name == name {
+			found = true
+		}
+	})
+	return found
+}
+
 func getJobOptions(ctx context.Context, streaming bool) (*dataflowlib.JobOptions, error) {
 	project := gcpopts.GetProjectFromFlagOrEnvironment(ctx)
 	if project == "" {
@@ -294,6 +305,17 @@ func getJobOptions(ctx context.Context, streaming bool) (*dataflowlib.JobOptions
 			return nil, errors.Wrapf(err, "error reading --transform_name_mapping flag as JSON")
 		}
 	}
+	if *usePublicIPs == *noUsePublicIPs {
+		useSet := isFlagPassed("use_public_ips")
+		noUseSet := isFlagPassed("no_use_public_ips")
+		// If use_public_ips was explicitly set but no_use_public_ips was not, use that value
+		// We take the explicit value of no_use_public_ips if it was set but use_public_ips was not.
+		if useSet && !noUseSet {
+			*noUsePublicIPs = !*usePublicIPs
+		} else if useSet && noUseSet {
+			return nil, errors.New("exactly one of usePublicIPs and noUsePublicIPs must be true, please check that only one is true")
+		}
+	}
 
 	hooks.SerializeHooksToOptions()
 
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow_test.go b/sdks/go/pkg/beam/runners/dataflow/dataflow_test.go
index d3518964da8a..663695f00c8e 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow_test.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow_test.go
@@ -427,6 +427,81 @@ func TestGetJobOptions_AliasAreEffective(t *testing.T) {
 	}
 }
 
+func TestGetJobOptions_BadTruePublicIPs(t *testing.T) {
+	resetGlobals()
+	*usePublicIPs = true
+	*noUsePublicIPs = true
+
+	opts, err := getJobOptions(context.Background(), false)
+	if err == nil {
+		t.Error("getJobOptions() returned error nil, want an error")
+	}
+	if opts != nil {
+		t.Errorf("getJobOptions() returned JobOptions when it should not have, got %#v, want nil", opts)
+	}
+}
+
+func TestGetJobOptions_BadFalsePublicIPs(t *testing.T) {
+	resetGlobals()
+	*usePublicIPs = false
+	*noUsePublicIPs = false
+
+	opts, err := getJobOptions(context.Background(), false)
+	if err == nil {
+		t.Error("getJobOptions() returned error nil, want an error")
+	}
+	if opts != nil {
+		t.Errorf("getJobOptions() returned JobOptions when it should not have, got %#v, want nil", opts)
+	}
+}
+
+func TestGetJobOptions_DefaultPublicIPs(t *testing.T) {
+	resetGlobals()
+	*labels = `{"label1": "val1", "label2": "val2"}`
+	*stagingLocation = "gs://testStagingLocation"
+	*minCPUPlatform = "testPlatform"
+	*flexRSGoal = "FLEXRS_SPEED_OPTIMIZED"
+	*dataflowServiceOptions = "opt1,opt2"
+
+	*gcpopts.Project = "testProject"
+	*gcpopts.Region = "testRegion"
+
+	*jobopts.Experiments = "use_runner_v2,use_portable_job_submission"
+	*jobopts.JobName = "testJob"
+
+	opts, err := getJobOptions(context.Background(), false)
+	if err != nil {
+		t.Fatalf("getJobOptions() returned error %q, want %q", err, "nil")
+	}
+	if got, want := opts.NoUsePublicIPs, false; got != want {
+		t.Errorf("getJobOptions().NoUsePublicIPs = %t, want %t", got, want)
+	}
+}
+
+func TestGetJobOptions_NoUsePublicIPs(t *testing.T) {
+	resetGlobals()
+	*labels = `{"label1": "val1", "label2": "val2"}`
+	*stagingLocation = "gs://testStagingLocation"
+	*minCPUPlatform = "testPlatform"
+	*flexRSGoal = "FLEXRS_SPEED_OPTIMIZED"
+	*dataflowServiceOptions = "opt1,opt2"
+	*noUsePublicIPs = true
+
+	*gcpopts.Project = "testProject"
+	*gcpopts.Region = "testRegion"
+
+	*jobopts.Experiments = "use_runner_v2,use_portable_job_submission"
+	*jobopts.JobName = "testJob"
+
+	opts, err := getJobOptions(context.Background(), false)
+	if err != nil {
+		t.Fatalf("getJobOptions() returned error %q, want %q", err, "nil")
+	}
+	if got, want := opts.NoUsePublicIPs, true; got != want {
+		t.Errorf("getJobOptions().NoUsePublicIPs = %t, want %t", got, want)
+	}
+}
+
 func getFieldFromOpt(fieldName string, opts *dataflowlib.JobOptions) string {
 	return reflect.ValueOf(opts).Elem().FieldByName(fieldName).String()
 }
@@ -447,6 +522,8 @@ func resetGlobals() {
 	*stagingLocation = ""
 	*transformMapping = ""
 	*update = false
+	*usePublicIPs = true
+	*noUsePublicIPs = false
 	*workerHarnessImage = ""
 	*workerMachineType = ""
 	*machineType = ""