diff --git a/repl/repl.go b/repl/repl.go index be0f33fbf3..2303b3ec5c 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -18,6 +18,7 @@ import ( "strings" "sync" + "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/compile" "github.com/open-policy-agent/opa/version" @@ -52,6 +53,7 @@ type REPL struct { strictBuiltinErrors bool capabilities *ast.Capabilities v1Compatible bool + initBundles map[string]*bundle.Bundle // TODO(tsandall): replace this state with rule definitions // inside the default module. @@ -132,6 +134,11 @@ func (r *REPL) WithCapabilities(capabilities *ast.Capabilities) *REPL { return r } +func (r *REPL) WithInitBundles(b map[string]*bundle.Bundle) *REPL { + r.initBundles = b + return r +} + func defaultModule() *ast.Module { return ast.MustParseModule(`package repl`) } @@ -1249,15 +1256,27 @@ func (r *REPL) loadHistory(prompt *liner.State) { } func (r *REPL) loadModules(ctx context.Context, txn storage.Transaction) (map[string]*ast.Module, error) { + modules := make(map[string]*ast.Module) + + if len(r.initBundles) > 0 { + for bundleName, b := range r.initBundles { + for name, module := range b.ParsedModules(bundleName) { + modules[name] = module + } + } + } ids, err := r.store.ListPolicies(ctx, txn) if err != nil { return nil, err } - modules := make(map[string]*ast.Module, len(ids)) - for _, id := range ids { + // skip re-parsing + if _, haveMod := modules[id]; haveMod { + continue + } + bs, err := r.store.GetPolicy(ctx, txn, id) if err != nil { return nil, err diff --git a/runtime/runtime.go b/runtime/runtime.go index 456c5cefb4..e512c3b050 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -259,15 +259,17 @@ type Runtime struct { Store storage.Store Manager *plugins.Manager - logger logging.Logger - server *server.Server - metrics *prometheus.Provider - reporter *report.Reporter - traceExporter *otlptrace.Exporter + logger logging.Logger + server *server.Server + metrics *prometheus.Provider + reporter *report.Reporter + traceExporter *otlptrace.Exporter + loadedPathsResult *initload.LoadPathsResult serverInitialized bool serverInitMtx sync.RWMutex done chan struct{} + repl *repl.REPL } // NewRuntime returns a new Runtime object initialized with params. Clients must @@ -453,6 +455,7 @@ func NewRuntime(ctx context.Context, params Params) (*Runtime, error) { reporter: reporter, serverInitialized: false, traceExporter: traceExporter, + loadedPathsResult: loaded, } return rt, nil @@ -708,7 +711,8 @@ func (rt *Runtime) StartREPL(ctx context.Context) { banner := rt.getBanner() repl := repl.New(rt.Store, rt.Params.HistoryPath, rt.Params.Output, rt.Params.OutputFormat, rt.Params.ErrorLimit, banner). WithRuntime(rt.Manager.Info). - WithV1Compatible(rt.Params.V1Compatible) + WithV1Compatible(rt.Params.V1Compatible). + WithInitBundles(rt.loadedPathsResult.Bundles) if rt.Params.Watch { if err := rt.startWatcher(ctx, rt.Params.Paths, onReloadPrinter(rt.Params.Output)); err != nil { @@ -722,6 +726,8 @@ func (rt *Runtime) StartREPL(ctx context.Context) { repl.SetOPAVersionReport(rt.checkOPAUpdate(ctx).Slice()) }() } + + rt.repl = repl repl.Loop(ctx) } diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 5ab40fba17..f2e21f64c1 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -235,6 +235,65 @@ func testRuntimeProcessWatchEventPolicyError(t *testing.T, asBundle bool) { }) } +func TestRuntimeReplWithBundleBuiltWithV1Compatibility(t *testing.T) { + ctx := context.Background() + + test.WithTempFS(nil, func(rootDir string) { + p := filepath.Join(rootDir, "bundle.tar.gz") + + mod := `package test + p := 7 if 3 < 4 + ` + + files := [][2]string{ + {"/.manifest", `{"revision": "foo", "rego_version": 1}`}, + {"/x.rego", mod}, + } + + buf := archive.MustWriteTarGz(files) + bf, err := os.Create(p) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + _, err = bf.Write(buf.Bytes()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + output := test.BlockingWriter{} + + params := NewParams() + params.Output = &output + params.Paths = []string{p} + params.BundleMode = true + + rt, err := NewRuntime(ctx, params) + if err != nil { + t.Fatal(err) + } + + go rt.StartREPL(ctx) + + if !test.Eventually(t, 5*time.Second, func() bool { + return strings.Contains(output.String(), "Run 'help' to see a list of commands and check for updates.") + }) { + t.Fatal("Timed out waiting for REPL to start") + } + output.Reset() + + if err := rt.repl.OneShot(ctx, "data.test.p"); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + actual := strings.TrimSpace(output.String()) + expected := "7" + + if actual != expected { + t.Fatalf("expected data.test.p to be %v, got %v", expected, actual) + } + }) +} + func TestRuntimeReplProcessWatchV1Compatible(t *testing.T) { tests := []struct { note string