From c55c2f0df3dfc4839a8033158b8f9a9ce4cdc203 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Wed, 5 Feb 2025 10:49:03 -0800 Subject: [PATCH 1/7] add delete operations to drop command --- engine/core/processors_has.go | 5 ++- engine/core/util.go | 9 ----- gripql/inspect/inspect.go | 9 ----- kvgraph/graph.go | 71 ++++++++++++++++++++++++++++------- kvgraph/graphdb.go | 4 ++ kvgraph/test/index_test.go | 19 +++++----- kvgraph/test/main_test.go | 9 ----- kvindex/entries.go | 9 +++-- kvindex/keys.go | 8 +++- kvindex/kvindex.go | 23 +++++------- 10 files changed, 96 insertions(+), 70 deletions(-) diff --git a/engine/core/processors_has.go b/engine/core/processors_has.go index 756da192..ba1f25ab 100644 --- a/engine/core/processors_has.go +++ b/engine/core/processors_has.go @@ -6,6 +6,7 @@ import ( "github.com/bmeg/grip/engine/logic" "github.com/bmeg/grip/gdbi" "github.com/bmeg/grip/gripql" + "github.com/bmeg/grip/util/setcmp" ) //////////////////////////////////////////////////////////////////////////////// @@ -50,7 +51,7 @@ func (h *HasLabel) Process(ctx context.Context, man gdbi.Manager, in gdbi.InPipe out <- t continue } - if contains(labels, t.GetCurrent().Get().Label) { + if setcmp.ContainsString(labels, t.GetCurrent().Get().Label) { out <- t } } @@ -106,7 +107,7 @@ func (h *HasID) Process(ctx context.Context, man gdbi.Manager, in gdbi.InPipe, o out <- t continue } - if contains(ids, t.GetCurrentID()) { + if setcmp.ContainsString(ids, t.GetCurrentID()) { out <- t } } diff --git a/engine/core/util.go b/engine/core/util.go index 4ec66f79..2cc616ef 100644 --- a/engine/core/util.go +++ b/engine/core/util.go @@ -8,15 +8,6 @@ func debug(i ...interface{}) { pretty.Println(i...) } -func contains(a []string, v string) bool { - for _, i := range a { - if i == v { - return true - } - } - return false -} - func dedupStringSlice(s []string) []string { seen := make(map[string]struct{}, len(s)) j := 0 diff --git a/gripql/inspect/inspect.go b/gripql/inspect/inspect.go index e4355601..da4911d9 100644 --- a/gripql/inspect/inspect.go +++ b/gripql/inspect/inspect.go @@ -22,15 +22,6 @@ func arrayEq(a, b []string) bool { return true } -func contains(a []string, n string) bool { - for _, c := range a { - if c == n { - return true - } - } - return false -} - // PipelineSteps create an array, the same length at stmts that labels the // step id for each of the GraphStatements func PipelineSteps(stmts []*gripql.GraphStatement) []string { diff --git a/kvgraph/graph.go b/kvgraph/graph.go index 8739f2d3..422028fd 100644 --- a/kvgraph/graph.go +++ b/kvgraph/graph.go @@ -11,20 +11,12 @@ import ( "github.com/bmeg/grip/kvi" "github.com/bmeg/grip/kvindex" "github.com/bmeg/grip/log" + "github.com/bmeg/grip/util/setcmp" "google.golang.org/protobuf/proto" multierror "github.com/hashicorp/go-multierror" ) -func contains(a []string, v string) bool { - for _, i := range a { - if i == v { - return true - } - } - return false -} - // GetTimestamp returns the update timestamp func (kgdb *KVInterfaceGDB) GetTimestamp() string { return kgdb.kvg.ts.Get(kgdb.graph) @@ -199,6 +191,10 @@ func (kgdb *KVInterfaceGDB) DelEdge(eid string) error { if err := kgdb.kvg.kv.Delete(dkey); err != nil { return err } + if err := kgdb.kvg.idx.RemoveDoc(eid); err != nil { + return err + } + kgdb.kvg.ts.Touch(kgdb.graph) return nil } @@ -235,6 +231,10 @@ func (kgdb *KVInterfaceGDB) DelVertex(id string) error { if err := tx.Delete(vid); err != nil { return err } + if err := 
kgdb.kvg.idx.RemoveDoc(kvindex.FieldKeyParse(vid)); err != nil { + return err + } + for _, k := range delKeys { if err := tx.Delete(k); err != nil { return err @@ -378,7 +378,7 @@ func (kgdb *KVInterfaceGDB) GetOutChannel(ctx context.Context, reqChan chan gdbi for it.Seek(skeyPrefix); it.Valid() && bytes.HasPrefix(it.Key(), skeyPrefix); it.Next() { keyValue := it.Key() _, _, dst, _, label, etype := SrcEdgeKeyParse(keyValue) - if len(edgeLabels) == 0 || contains(edgeLabels, label) { + if len(edgeLabels) == 0 || setcmp.ContainsString(edgeLabels, label) { vkey := VertexKey(kgdb.graph, dst) if etype == edgeSingle { vertexChan <- elementData{ @@ -456,7 +456,7 @@ func (kgdb *KVInterfaceGDB) GetInChannel(ctx context.Context, reqChan chan gdbi. for it.Seek(dkeyPrefix); it.Valid() && bytes.HasPrefix(it.Key(), dkeyPrefix); it.Next() { keyValue := it.Key() _, src, _, _, label, _ := DstEdgeKeyParse(keyValue) - if len(edgeLabels) == 0 || contains(edgeLabels, label) { + if len(edgeLabels) == 0 || setcmp.ContainsString(edgeLabels, label) { vkey := VertexKey(kgdb.graph, src) dataValue, err := it.Get(vkey) if err == nil { @@ -506,7 +506,7 @@ func (kgdb *KVInterfaceGDB) GetOutEdgeChannel(ctx context.Context, reqChan chan for it.Seek(skeyPrefix); it.Valid() && bytes.HasPrefix(it.Key(), skeyPrefix); it.Next() { keyValue := it.Key() _, src, dst, eid, label, edgeType := SrcEdgeKeyParse(keyValue) - if len(edgeLabels) == 0 || contains(edgeLabels, label) { + if len(edgeLabels) == 0 || setcmp.ContainsString(edgeLabels, label) { if edgeType == edgeSingle { e := gdbi.Edge{} if load { @@ -563,7 +563,7 @@ func (kgdb *KVInterfaceGDB) GetInEdgeChannel(ctx context.Context, reqChan chan g for it.Seek(dkeyPrefix); it.Valid() && bytes.HasPrefix(it.Key(), dkeyPrefix); it.Next() { keyValue := it.Key() _, src, dst, eid, label, edgeType := DstEdgeKeyParse(keyValue) - if len(edgeLabels) == 0 || contains(edgeLabels, label) { + if len(edgeLabels) == 0 || setcmp.ContainsString(edgeLabels, label) { if edgeType == edgeSingle { e := gdbi.Edge{} if load { @@ -678,6 +678,51 @@ func (kgdb *KVInterfaceGDB) GetVertexList(ctx context.Context, loadProp bool) <- return o } +func (kgdb *KVInterfaceGDB) DeleteAllData(ctx context.Context, graph string) error { + go func() { + kgdb.kvg.kv.View(func(it kvi.KVIterator) error { + ePrefix := EdgeListPrefix(graph) + for it.Seek(ePrefix); it.Valid() && bytes.HasPrefix(it.Key(), ePrefix); it.Next() { + select { + case <-ctx.Done(): + return nil + default: + } + keyValue := it.Key() + _, eid, _, _, _, etype := EdgeKeyParse(keyValue) + if etype == edgeSingle { + kgdb.DelEdge(string(eid)) + } + } + return nil + }) + }() + + go func() { + kgdb.kvg.kv.View(func(it kvi.KVIterator) error { + vPrefix := VertexListPrefix(graph) + + for it.Seek(vPrefix); it.Valid() && bytes.HasPrefix(it.Key(), vPrefix); it.Next() { + select { + case <-ctx.Done(): + return nil + default: + } + gv := &gripql.Vertex{} + dataValue, _ := it.Value() + proto.Unmarshal(dataValue, gv) + keyValue := it.Key() + _, vid := VertexKeyParse(keyValue) + _ = kgdb.DelVertex(vid) + + } + return nil + }) + }() + + return nil +} + // ListVertexLabels returns a list of vertex types in the graph func (kgdb *KVInterfaceGDB) ListVertexLabels() ([]string, error) { labelField := fmt.Sprintf("%s.v.label", kgdb.graph) diff --git a/kvgraph/graphdb.go b/kvgraph/graphdb.go index ea261874..5dc165c3 100644 --- a/kvgraph/graphdb.go +++ b/kvgraph/graphdb.go @@ -2,6 +2,7 @@ package kvgraph import ( "bytes" + "context" "fmt" "github.com/bmeg/grip/gdbi" @@ -43,6 
+44,9 @@ func (kgraph *KVGraph) DeleteGraph(graph string) error { graphKey := GraphKey(graph) kgraph.kv.Delete(graphKey) + kvgdb := KVInterfaceGDB{kvg: kgraph, graph: graph} + kvgdb.DeleteAllData(context.Background(), graph) + kgraph.deleteGraphIndex(graph) return nil diff --git a/kvgraph/test/index_test.go b/kvgraph/test/index_test.go index 0062f0e8..92d2a4aa 100644 --- a/kvgraph/test/index_test.go +++ b/kvgraph/test/index_test.go @@ -7,6 +7,7 @@ import ( "context" "github.com/bmeg/grip/kvindex" + "github.com/bmeg/grip/util/setcmp" ) var docs = `[ @@ -61,7 +62,7 @@ func TestFieldListing(t *testing.T) { count := 0 for _, field := range idx.ListFields() { - if !contains(newFields, field) { + if !setcmp.ContainsString(newFields, field) { t.Errorf("Bad field return: %s", field) } count++ @@ -89,7 +90,7 @@ func TestLoadDoc(t *testing.T) { count := 0 for d := range idx.GetTermMatch(context.Background(), "v.label", "Person", -1) { - if !contains(personDocs, d) { + if !setcmp.ContainsString(personDocs, d) { t.Errorf("Bad doc return: %s", d) } count++ @@ -100,7 +101,7 @@ func TestLoadDoc(t *testing.T) { count = 0 for d := range idx.GetTermMatch(context.Background(), "v.data.firstName", "Bob", -1) { - if !contains(bobDocs, d) { + if !setcmp.ContainsString(bobDocs, d) { t.Errorf("Bad doc return: %s", d) } count++ @@ -128,7 +129,7 @@ func TestTermEnum(t *testing.T) { count := 0 for d := range idx.FieldTerms("v.data.lastName") { count++ - if !contains(lastNames, d.(string)) { + if !setcmp.ContainsString(lastNames, d.(string)) { t.Errorf("Bad term return: %s", d) } } @@ -139,7 +140,7 @@ func TestTermEnum(t *testing.T) { count = 0 for d := range idx.FieldTerms("v.data.firstName") { count++ - if !contains(firstNames, d.(string)) { + if !setcmp.ContainsString(firstNames, d.(string)) { t.Errorf("Bad term return: %s", d) } } @@ -166,7 +167,7 @@ func TestTermCount(t *testing.T) { count := 0 for d := range idx.FieldStringTermCounts("v.data.lastName") { count++ - if !contains(lastNames, d.String) { + if !setcmp.ContainsString(lastNames, d.String) { t.Errorf("Bad term return: %s", d.String) } if d.String == "Smith" { @@ -182,7 +183,7 @@ func TestTermCount(t *testing.T) { count = 0 for d := range idx.FieldTermCounts("v.data.firstName") { count++ - if !contains(firstNames, d.String) { + if !setcmp.ContainsString(firstNames, d.String) { t.Errorf("Bad term return: %s", d.String) } } @@ -212,7 +213,7 @@ func TestDocDelete(t *testing.T) { count := 0 for d := range idx.FieldStringTermCounts("v.data.lastName") { count++ - if !contains(lastNames, d.String) { + if !setcmp.ContainsString(lastNames, d.String) { t.Errorf("Bad term return: %s", d.String) } if d.String == "Smith" { @@ -233,7 +234,7 @@ func TestDocDelete(t *testing.T) { count = 0 for d := range idx.FieldStringTermCounts("v.data.lastName") { count++ - if !contains(lastNames, d.String) { + if !setcmp.ContainsString(lastNames, d.String) { t.Errorf("Bad term return: %s", d.String) } if d.String == "Smith" { diff --git a/kvgraph/test/main_test.go b/kvgraph/test/main_test.go index 3cd61715..01f56fda 100644 --- a/kvgraph/test/main_test.go +++ b/kvgraph/test/main_test.go @@ -28,15 +28,6 @@ func resetKVInterface() { } } -func contains(a []string, v string) bool { - for _, i := range a { - if i == v { - return true - } - } - return false -} - func TestMain(m *testing.M) { var err error var exit = 1 diff --git a/kvindex/entries.go b/kvindex/entries.go index 5859f185..4636c332 100644 --- a/kvindex/entries.go +++ b/kvindex/entries.go @@ -4,6 +4,8 @@ import 
(
	"encoding/binary"
	"fmt"
	"math"
+
+	"github.com/bmeg/grip/util/setcmp"
)
// TermType defines in a term is a Number or a String
@@ -37,18 +39,19 @@ func fieldScan(docID string, doc map[string]interface{}, fieldPrefix string, fie
	if containsPrefix(f, fields) {
		if x, ok := v.(map[string]interface{}); ok {
			fieldScan(docID, x, fmt.Sprintf("%s.%s", fieldPrefix, k), fields, out)
-		} else if contains(f, fields) {
+		} else if setcmp.ContainsString(fields, f) {
			out <- newEntry(docID, f, v)
		}
	}
}
}
-func mapDig(i map[string]interface{}, path []string) interface{} {
+// Given a list of fields (graphs), return term (label) of doc (graph element) if it exists on field
+func getTermOnField(i map[string]interface{}, path []string) interface{} {
	if x, ok := i[path[0]]; ok {
		if len(path) > 1 {
			if y, ok := x.(map[string]interface{}); ok {
-				return mapDig(y, path[1:])
+				return getTermOnField(y, path[1:])
			}
		} else {
			return x
diff --git a/kvindex/keys.go
index a0dc5942..a9155735 100644
--- a/kvindex/keys.go
+++ b/kvindex/keys.go
@@ -6,21 +6,25 @@ import (
// Fields
// key: f | field
+// Known as 'graph' in grip
// val:
var idxFieldPrefix = []byte("f")
// Terms
// key: t | field | TermType | term
+// Known as 'label' in grip
// val: count
var idxTermPrefix = []byte("t")
// Entries
// key: i | field | TermType | term | docid
+// What links the graph + label to the doc via an id
// val:
var idxEntryPrefix = []byte("i")
// Docs
// key: d | docid
+// Known as 'data' in grip
// val: Doc entry list
var idxDocPrefix = []byte("D")
@@ -75,7 +79,7 @@ func EntryPrefix(field string) []byte {
	return bytes.Join([][]byte{idxEntryPrefix, []byte(field), {}}, []byte{0})
}
-// EntryTypePrefix get prefix for all entries for a single field
+// EntryTypePrefix get prefix for all entries for a single field and a single type
func EntryTypePrefix(field string, ttype TermType) []byte {
	return bytes.Join([][]byte{idxEntryPrefix, []byte(field), {byte(ttype)}, {}}, []byte{0})
}
diff --git a/kvindex/kvindex.go
index 35762f49..6bcdf63d 100644
--- a/kvindex/kvindex.go
+++ b/kvindex/kvindex.go
@@ -16,15 +16,6 @@ import (
const bufferSize = 1000
-func contains(c string, s []string) bool {
-	for _, i := range s {
-		if c == i {
-			return true
-		}
-	}
-	return false
-}
-
func containsPrefix(c string, s []string) bool {
	for _, i := range s {
		if strings.HasPrefix(i, c) {
@@ -64,8 +55,12 @@ func (idx *KVIndex) RemoveField(path string) error {
	fk := FieldKey(path)
	fkt := TermPrefix(path)
	ed := EntryPrefix(path)
+	dk := DocKey(path)
+
	idx.KV.DeletePrefix(fkt)
	idx.KV.DeletePrefix(ed)
+	idx.KV.DeletePrefix(dk)
+
	delete(idx.Fields, path)
	return idx.KV.Delete(fk)
}
@@ -101,19 +96,19 @@ func (idx *KVIndex) AddDocTx(tx kvi.KVBulkWrite, docID string, doc map[string]in
	docKey := DocKey(docID)
	for field, p := range idx.Fields {
-		x := mapDig(doc, p)
-		if x != nil {
-			term, t := GetTermBytes(x)
+		term := getTermOnField(doc, p)
+		if term != nil {
+			termBytes, t := GetTermBytes(term)
			switch t {
			case TermString, TermNumber:
-				entryKey := EntryKey(field, t, term, docID)
+				entryKey := EntryKey(field, t, termBytes, docID)
				err := tx.Set(entryKey, []byte{})
				if err != nil {
					return fmt.Errorf("failed to set entry key %s: %v", entryKey, err)
				}
				sdoc.Entries = append(sdoc.Entries, entryKey)
-				termKey := TermKey(field, t, term)
+				termKey := TermKey(field, t, termBytes)
				//set the term count to 0 to invalidate it.
Later on, if other code trying //to get the term count will have to recount //previously, it was a set(get+1), but for bulk loading, its better From 09004a723c676e744622dee6e4d9eb02d3b7c3c8 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Thu, 6 Feb 2025 09:01:51 -0800 Subject: [PATCH 2/7] swap tests to pebble, fix pebble graph drop command --- .github/workflows/tests.yml | 2 +- config/config.go | 4 ++ kvgraph/graph.go | 53 ++----------------- kvgraph/graphdb.go | 4 -- kvi/pebbledb/pebble_store.go | 23 +------- test/main_test.go | 36 ++++++++++++- test/{badger-auth.yml => pebble-auth.yml} | 6 +-- ...r-proxy-auth.yml => pebble-proxy-auth.yml} | 6 +-- 8 files changed, 53 insertions(+), 81 deletions(-) rename test/{badger-auth.yml => pebble-auth.yml} (85%) rename test/{badger-proxy-auth.yml => pebble-proxy-auth.yml} (80%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 226630d4..64fa0d07 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,7 +42,7 @@ jobs: uses: actions/checkout@v4 - name: run unit tests run: | - go test ./test/... -config badger.yml + go test ./test/... -config pebble.yml badgerTest: diff --git a/config/config.go b/config/config.go index db684f29..4ab86693 100644 --- a/config/config.go +++ b/config/config.go @@ -129,6 +129,10 @@ func TestifyConfig(c *Config) { a := "grip.db." + rand d.Badger = &a } + if d.Pebble != nil { + a := "grip.db." + rand + d.Pebble = &a + } if d.MongoDB != nil { d.MongoDB.DBName = "gripdb-" + rand } diff --git a/kvgraph/graph.go b/kvgraph/graph.go index 422028fd..e18a89bd 100644 --- a/kvgraph/graph.go +++ b/kvgraph/graph.go @@ -231,15 +231,17 @@ func (kgdb *KVInterfaceGDB) DelVertex(id string) error { if err := tx.Delete(vid); err != nil { return err } - if err := kgdb.kvg.idx.RemoveDoc(kvindex.FieldKeyParse(vid)); err != nil { - return err - } for _, k := range delKeys { if err := tx.Delete(k); err != nil { return err } } + + if err := kgdb.kvg.idx.RemoveDoc(kvindex.FieldKeyParse(vid)); err != nil { + return err + } + kgdb.kvg.ts.Touch(kgdb.graph) return nil }) @@ -678,51 +680,6 @@ func (kgdb *KVInterfaceGDB) GetVertexList(ctx context.Context, loadProp bool) <- return o } -func (kgdb *KVInterfaceGDB) DeleteAllData(ctx context.Context, graph string) error { - go func() { - kgdb.kvg.kv.View(func(it kvi.KVIterator) error { - ePrefix := EdgeListPrefix(graph) - for it.Seek(ePrefix); it.Valid() && bytes.HasPrefix(it.Key(), ePrefix); it.Next() { - select { - case <-ctx.Done(): - return nil - default: - } - keyValue := it.Key() - _, eid, _, _, _, etype := EdgeKeyParse(keyValue) - if etype == edgeSingle { - kgdb.DelEdge(string(eid)) - } - } - return nil - }) - }() - - go func() { - kgdb.kvg.kv.View(func(it kvi.KVIterator) error { - vPrefix := VertexListPrefix(graph) - - for it.Seek(vPrefix); it.Valid() && bytes.HasPrefix(it.Key(), vPrefix); it.Next() { - select { - case <-ctx.Done(): - return nil - default: - } - gv := &gripql.Vertex{} - dataValue, _ := it.Value() - proto.Unmarshal(dataValue, gv) - keyValue := it.Key() - _, vid := VertexKeyParse(keyValue) - _ = kgdb.DelVertex(vid) - - } - return nil - }) - }() - - return nil -} - // ListVertexLabels returns a list of vertex types in the graph func (kgdb *KVInterfaceGDB) ListVertexLabels() ([]string, error) { labelField := fmt.Sprintf("%s.v.label", kgdb.graph) diff --git a/kvgraph/graphdb.go b/kvgraph/graphdb.go index 5dc165c3..ea261874 100644 --- a/kvgraph/graphdb.go +++ b/kvgraph/graphdb.go @@ -2,7 +2,6 @@ package kvgraph import ( "bytes" - 
"context" "fmt" "github.com/bmeg/grip/gdbi" @@ -44,9 +43,6 @@ func (kgraph *KVGraph) DeleteGraph(graph string) error { graphKey := GraphKey(graph) kgraph.kv.Delete(graphKey) - kvgdb := KVInterfaceGDB{kvg: kgraph, graph: graph} - kvgdb.DeleteAllData(context.Background(), graph) - kgraph.deleteGraphIndex(graph) return nil diff --git a/kvi/pebbledb/pebble_store.go b/kvi/pebbledb/pebble_store.go index b215183b..d239c741 100644 --- a/kvi/pebbledb/pebble_store.go +++ b/kvi/pebbledb/pebble_store.go @@ -69,27 +69,8 @@ func (pdb *PebbleKV) Delete(id []byte) error { // DeletePrefix deletes all elements in kvstore that begin with prefix `id` func (pdb *PebbleKV) DeletePrefix(prefix []byte) error { - deleteBlockSize := 10000 - for found := true; found; { - found = false - wb := make([][]byte, 0, deleteBlockSize) - it, err := pdb.db.NewIter(&pebble.IterOptions{LowerBound: prefix}) - if err != nil { - return err - } - for ; it.Valid() && bytes.HasPrefix(it.Key(), prefix) && len(wb) < deleteBlockSize-1; it.Next() { - wb = append(wb, copyBytes(it.Key())) - } - it.Close() - for _, i := range wb { - err := pdb.db.Delete(i, nil) - if err != nil { - return err - } - found = true - } - } - return nil + nextPrefix := append(prefix, 0xFF) + return pdb.db.DeleteRange(prefix, nextPrefix, nil) } // HasKey returns true if the key is exists in kvstore diff --git a/test/main_test.go b/test/main_test.go index 36a71899..e1673a0b 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -17,6 +17,8 @@ import ( _ "github.com/bmeg/grip/kvi/badgerdb" // import so badger will register itself _ "github.com/bmeg/grip/kvi/boltdb" // import so bolt will register itself _ "github.com/bmeg/grip/kvi/leveldb" // import so level will register itself + _ "github.com/bmeg/grip/kvi/pebbledb" // import so pebble will register itself + "github.com/bmeg/grip/mongo" "github.com/bmeg/grip/psql" "github.com/bmeg/grip/util" @@ -93,7 +95,7 @@ func TestMain(m *testing.M) { return } } else { - conf.AddBadgerDefault() + conf.AddPebbleDefault() } config.TestifyConfig(conf) @@ -116,6 +118,11 @@ func TestMain(m *testing.M) { defer func() { os.RemoveAll(*dbconfig.Badger) }() + } else if dbconfig.Pebble != nil { + gdb, err = kvgraph.NewKVGraphDB("pebble", *dbconfig.Pebble) + defer func() { + os.RemoveAll(*dbconfig.Pebble) + }() } else if dbconfig.Bolt != nil { gdb, err = kvgraph.NewKVGraphDB("bolt", *dbconfig.Bolt) defer func() { @@ -159,6 +166,33 @@ func TestMain(m *testing.M) { } } + // After deleting graph, docs, entries, fields should no longer exist in doc + err = gdb.DeleteGraph("test-graph") + err = gdb.AddGraph("test-graph") + if err != nil { + fmt.Println("Error: failed to add graph:", err) + return + } + db, err = gdb.Graph("test-graph") + if err != nil { + fmt.Println("Error: failed to connect to graph:", err) + return + } + + afterVertexLabels, _ := db.ListVertexLabels() + afterEdgeLabels, _ := db.ListEdgeLabels() + fmt.Printf("afterEdgeLabels: %s afterVertexLabels: %s\n", afterEdgeLabels, afterVertexLabels) + if len(afterVertexLabels) != 0 || len(afterEdgeLabels) != 0 { + panic(fmt.Errorf("afterEdgeLabels: %s or afterVertexLabels: %s are not empty\n", afterEdgeLabels, afterVertexLabels)) + } + + if dbname != "existing-sql" { + err = setupGraph() + if err != nil { + fmt.Println("Error: setting up graph:", err) + return + } + } // run tests exit = m.Run() } diff --git a/test/badger-auth.yml b/test/pebble-auth.yml similarity index 85% rename from test/badger-auth.yml rename to 
test/pebble-auth.yml index 560931e1..f7c9757b 100644 --- a/test/badger-auth.yml +++ b/test/pebble-auth.yml @@ -1,8 +1,8 @@ -Default: badger +Default: pebble Drivers: - badger: - Badger: grip-badger.db + pebble: + Pebble: grip-pebble.db Server: Accounts: diff --git a/test/badger-proxy-auth.yml b/test/pebble-proxy-auth.yml similarity index 80% rename from test/badger-proxy-auth.yml rename to test/pebble-proxy-auth.yml index 5862c84b..8b8c0403 100644 --- a/test/badger-proxy-auth.yml +++ b/test/pebble-proxy-auth.yml @@ -1,8 +1,8 @@ -Default: badger +Default: pebble Drivers: - badger: - Badger: grip-badger.db + pebble: + Pebble: grip-pebble.db Server: Accounts: From a9b9bd3b66fc1d1f0058d93f0927e553945e617e Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Thu, 6 Feb 2025 09:06:26 -0800 Subject: [PATCH 3/7] update auth test to pebble --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 64fa0d07..0aa98fc7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -217,7 +217,7 @@ jobs: run: | # start grip server chmod +x grip - ./grip server --rpc-port 18202 --http-port 18201 --config ./test/badger-auth.yml & + ./grip server --rpc-port 18202 --http-port 18201 --config ./test/pebble-auth.yml & sleep 5 # simple auth # run tests without credentials, should fail @@ -228,7 +228,7 @@ jobs: echo "Got expected auth error" fi # run specialized role based tests - make test-authorization ARGS="--grip_config_file_path test/badger-auth.yml" + make test-authorization ARGS="--grip_config_file_path test/pebble-auth.yml" From c2adf0ef912c15152e87bfcc003f686e35949968 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Thu, 20 Feb 2025 15:11:58 -0800 Subject: [PATCH 4/7] address oom issues --- server/api.go | 37 ++++++++++++++++++++----------------- server/server.go | 40 +++++++++++++++++++++++++--------------- 2 files changed, 45 insertions(+), 32 deletions(-) diff --git a/server/api.go b/server/api.go index 1ecc8e33..38bf0124 100644 --- a/server/api.go +++ b/server/api.go @@ -229,7 +229,7 @@ func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error var populated bool var sch *gripql.Graph out := &graph.GraphSchema{Classes: map[string]*jsonschema.Schema{}, Compiler: nil} - elementStream := make(chan *gdbi.GraphElement) + elementStream := make(chan *gdbi.GraphElement, 100) var retErrs []string for { var err error @@ -271,6 +271,11 @@ func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error wg.Add(1) go func() { defer wg.Done() + defer func() { + for ge := range elementStream { + server.streamPool.Put(ge) + } + }() err := graph.BulkAdd(elementStream) if err != nil { log.WithFields(log.Fields{"graph": class.Graph, "error": err}).Error("BulkAddRaw: error") @@ -298,27 +303,24 @@ func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error } for _, element := range result { + ge := server.streamPool.Get().(*gdbi.GraphElement) + ge.Graph = class.Graph if element.Vertex != nil { - elementStream <- &gdbi.GraphElement{ - Vertex: &gdbi.Vertex{ - ID: element.Vertex.Gid, - Data: element.Vertex.Data.AsMap(), - Label: element.Vertex.Label, - }, - Graph: class.Graph, + ge.Vertex = &gdbi.Vertex{ + ID: element.Vertex.Gid, + Data: element.Vertex.Data.AsMap(), + Label: element.Vertex.Label, } } else { - elementStream <- &gdbi.GraphElement{ - Edge: &gdbi.Edge{ - ID: element.Edge.Gid, - Label: element.Edge.Label, - From: 
element.Edge.From, - To: element.Edge.To, - Data: element.Edge.Data.AsMap(), - }, - Graph: class.Graph, + ge.Edge = &gdbi.Edge{ + ID: element.Edge.Gid, + Label: element.Edge.Label, + From: element.Edge.From, + To: element.Edge.To, + Data: element.Edge.Data.AsMap(), } } + elementStream <- ge insertCount++ } @@ -410,6 +412,7 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { } else { insertCount++ elementStream <- gdbi.NewGraphElement(element) + } } } diff --git a/server/server.go b/server/server.go index 5deb2942..04662daf 100644 --- a/server/server.go +++ b/server/server.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "github.com/bmeg/grip/config" @@ -42,15 +43,16 @@ type GripServer struct { gripql.UnimplementedEditServer gripql.UnimplementedJobServer gripql.UnimplementedConfigureServer - dbs map[string]gdbi.GraphDB //graph database drivers - graphMap map[string]string //mapping from graph name to graph database driver - conf *config.Config //global configuration - schemas map[string]*gripql.Graph //cached schemas - mappings map[string]*gripql.Graph //cached gripper graph mappings - plugins map[string]*Plugin - sources map[string]gripper.GRIPSourceClient - baseDir string - jStorage jobstorage.JobStorage + dbs map[string]gdbi.GraphDB //graph database drivers + graphMap map[string]string //mapping from graph name to graph database driver + conf *config.Config //global configuration + schemas map[string]*gripql.Graph //cached schemas + mappings map[string]*gripql.Graph //cached gripper graph mappings + plugins map[string]*Plugin + sources map[string]gripper.GRIPSourceClient + baseDir string + jStorage jobstorage.JobStorage + streamPool *sync.Pool } // NewGripServer initializes a GRPC server to connect to the graph store @@ -92,13 +94,21 @@ func NewGripServer(conf *config.Config, baseDir string, drivers map[string]gdbi. 
} } + // Add an element pool for managing resources when streaming data + var graphElementPool = &sync.Pool{ + New: func() interface{} { + return &gdbi.GraphElement{} + }, + } + server := &GripServer{ - dbs: gdbs, - conf: conf, - schemas: schemas, - mappings: map[string]*gripql.Graph{}, - plugins: map[string]*Plugin{}, - sources: sources, + dbs: gdbs, + conf: conf, + schemas: schemas, + mappings: map[string]*gripql.Graph{}, + plugins: map[string]*Plugin{}, + sources: sources, + streamPool: graphElementPool, } if conf.Default == "" { From 1415e5997112d17d3525ef825e3920655d270bc7 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Fri, 21 Feb 2025 09:33:44 -0800 Subject: [PATCH 5/7] change stream func to pass chan instead of list --- mongo/graph.go | 57 ++++++++++++++++++++- psql/graph.go | 129 +++++++++++++++++++++++++++++++++++++++++++++++- server/api.go | 4 ++ sqlite/graph.go | 127 ++++++++++++++++++++++++++++++++++++++++++++++- util/insert.go | 101 +++++++++++++++++++------------------ 5 files changed, 368 insertions(+), 50 deletions(-) diff --git a/mongo/graph.go b/mongo/graph.go index aec91b8b..c915ace6 100644 --- a/mongo/graph.go +++ b/mongo/graph.go @@ -116,8 +116,63 @@ func (mg *Graph) AddEdge(edges []*gdbi.Edge) error { return err } +func (mg *Graph) StreamEdges(edgeChan <-chan *gdbi.Edge, batchsize int) error { + eCol := mg.ar.EdgeCollection(mg.graph) + var err error + docBatch := make([]mongo.WriteModel, 0, batchsize) + + for edge := range edgeChan { + i := mongo.NewReplaceOneModel().SetUpsert(true).SetFilter(bson.M{FIELD_ID: edge.ID}) + ent := PackEdge(edge) + i.SetReplacement(ent) + docBatch = append(docBatch, i) + + if len(docBatch) >= batchsize { + _, err = eCol.BulkWrite(context.Background(), docBatch) + if err != nil { + log.Errorf("StreamEdges error: (%s) %s", docBatch, err) + } + docBatch = docBatch[:0] + } + } + if len(docBatch) > 0 { + _, err = eCol.BulkWrite(context.Background(), docBatch) + if err != nil { + log.Errorf("StreamEdges error: (%s) %s", docBatch, err) + } + } + return err +} + +func (mg *Graph) StreamVertices(vertChan <-chan *gdbi.Vertex, batchsize int) error { + vCol := mg.ar.VertexCollection(mg.graph) + var err error + docBatch := make([]mongo.WriteModel, 0, batchsize) + for v := range vertChan { + i := mongo.NewReplaceOneModel().SetUpsert(true).SetFilter(bson.M{FIELD_ID: v.ID}) + ent := PackVertex(v) + i.SetReplacement(ent) + docBatch = append(docBatch, i) + + if len(docBatch) >= batchsize { + _, err = vCol.BulkWrite(context.Background(), docBatch) + if err != nil { + log.Errorf("StreamVertices error: (%s) %s", docBatch, err) + } + docBatch = docBatch[:0] + } + } + if len(docBatch) > 0 { + _, err = vCol.BulkWrite(context.Background(), docBatch) + if err != nil { + log.Errorf("StreamVertices error: (%s) %s", docBatch, err) + } + } + return err +} + func (mg *Graph) BulkAdd(stream <-chan *gdbi.GraphElement) error { - return util.StreamBatch(stream, 50, mg.graph, mg.AddVertex, mg.AddEdge) + return util.StreamBatch(stream, 100, mg.graph, mg.StreamVertices, mg.StreamEdges) } func (mg *Graph) BulkDel(Data *gdbi.DeleteData) error { diff --git a/psql/graph.go b/psql/graph.go index 5c841355..7f0db24c 100644 --- a/psql/graph.go +++ b/psql/graph.go @@ -79,6 +79,133 @@ func (g *Graph) AddVertex(vertices []*gdbi.Vertex) error { return nil } +// AddVertex adds a vertex to the database +func (g *Graph) StreamVertices(vertices <-chan *gdbi.Vertex, workers int) error { + txn, err := g.db.Begin() + if err != nil { + return fmt.Errorf("StreamVertices: Begin Txn: %v", 
err) + } + + s := fmt.Sprintf( + `INSERT INTO %s (gid, label, data) VALUES ($1, $2, $3) + ON CONFLICT (gid) DO UPDATE SET + gid = excluded.gid, + label = excluded.label, + data = excluded.data;`, + g.v, + ) + stmt, err := txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamVertices: Prepare Stmt: %v", err) + } + + count := 0 + for v := range vertices { + js, err := json.Marshal(v.Data) + if err != nil { + return fmt.Errorf("StreamVertices: Stmt.Exec: %v", err) + } + _, err = stmt.Exec(v.ID, v.Label, js) + if err != nil { + return fmt.Errorf("StreamVertices: Stmt.Exec: %v", err) + } + count++ + + if count%1000 == 0 { + if err := txn.Commit(); err != nil { + _ = stmt.Close() + return fmt.Errorf("StreamVertices: Txn.Commit: %v", err) + } + + txn, err = g.db.Begin() + if err != nil { + return fmt.Errorf("StreamVertices: Begin New Txn: %v", err) + } + stmt, err = txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamVertices: Prepare New Stmt: %v", err) + } + } + + } + + err = stmt.Close() + if err != nil { + return fmt.Errorf("StreamVertices: Stmt.Close: %v", err) + } + + err = txn.Commit() + if err != nil { + return fmt.Errorf("StreamVertices: Txn.Commit: %v", err) + } + + return nil +} + +// AddEdge adds an edge to the database +func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { + txn, err := g.db.Begin() + if err != nil { + return fmt.Errorf("StreamEdges: Begin Txn: %v", err) + } + + s := fmt.Sprintf( + `INSERT INTO %s (gid, label, "from", "to", data) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (gid) DO UPDATE SET + gid = excluded.gid, + label = excluded.label, + "from" = excluded.from, + "to" = excluded.to, + data = excluded.data;`, + g.e, + ) + stmt, err := txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamEdges: Prepare Stmt: %v", err) + } + + count := 0 + for e := range edges { + js, err := json.Marshal(e.Data) + if err != nil { + return fmt.Errorf("AddEdge: Stmt.Exec: %v", err) + } + _, err = stmt.Exec(e.ID, e.Label, e.From, e.To, js) + if err != nil { + return fmt.Errorf("AddEdge: Stmt.Exec: %v", err) + } + count++ + if count%1000 == 0 { + if err := txn.Commit(); err != nil { + _ = stmt.Close() + return fmt.Errorf("StreamEdges: Txn.Commit: %v", err) + } + + txn, err = g.db.Begin() + if err != nil { + return fmt.Errorf("StreamEdges: Begin New Txn: %v", err) + } + stmt, err = txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamEdges: Prepare New Stmt: %v", err) + } + } + + } + + err = stmt.Close() + if err != nil { + return fmt.Errorf("StreamEdges: Stmt.Close: %v", err) + } + + err = txn.Commit() + if err != nil { + return fmt.Errorf("StreamEdges: Txn.Commit: %v", err) + } + + return nil +} + // AddEdge adds an edge to the database func (g *Graph) AddEdge(edges []*gdbi.Edge) error { txn, err := g.db.Begin() @@ -126,7 +253,7 @@ func (g *Graph) AddEdge(edges []*gdbi.Edge) error { } func (g *Graph) BulkAdd(stream <-chan *gdbi.GraphElement) error { - return util.StreamBatch(stream, 50, g.graph, g.AddVertex, g.AddEdge) + return util.StreamBatch(stream, 50, g.graph, g.StreamVertices, g.StreamEdges) } func (g *Graph) BulkDel(Data *gdbi.DeleteData) error { diff --git a/server/api.go b/server/api.go index 38bf0124..952a2ce6 100644 --- a/server/api.go +++ b/server/api.go @@ -231,6 +231,7 @@ func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error out := &graph.GraphSchema{Classes: map[string]*jsonschema.Schema{}, Compiler: nil} elementStream := make(chan *gdbi.GraphElement, 100) var retErrs []string + sem := make(chan 
struct{}, 1000) for { var err error class, err := stream.Recv() @@ -320,13 +321,16 @@ func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error Data: element.Edge.Data.AsMap(), } } + sem <- struct{}{} elementStream <- ge insertCount++ + <-sem } } close(elementStream) wg.Wait() + close(sem) return stream.SendAndClose(&gripql.BulkJsonEditResult{InsertCount: insertCount, Errors: retErrs}) } diff --git a/sqlite/graph.go b/sqlite/graph.go index 7281f25e..9b780606 100644 --- a/sqlite/graph.go +++ b/sqlite/graph.go @@ -77,6 +77,131 @@ func (g *Graph) AddVertex(vertices []*gdbi.Vertex) error { return nil } +func (g *Graph) StreamVertices(vertices <-chan *gdbi.Vertex, workers int) error { + txn, err := g.db.Begin() + if err != nil { + return fmt.Errorf("StreamVertices: Begin Txn: %v", err) + } + + s := fmt.Sprintf( + `INSERT INTO %s (gid, label, data) VALUES ($1, $2, $3) + ON CONFLICT (gid) DO UPDATE SET + gid = excluded.gid, + label = excluded.label, + data = excluded.data;`, + g.v, + ) + stmt, err := txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamVertices: Prepare Stmt: %v", err) + } + + count := 0 + for v := range vertices { + js, err := json.Marshal(v.Data) + if err != nil { + return fmt.Errorf("StreamVertices: Stmt.Exec: %v", err) + } + _, err = stmt.Exec(v.ID, v.Label, js) + if err != nil { + return fmt.Errorf("StreamVertices: Stmt.Exec: %v", err) + } + count++ + + if count%1000 == 0 { + if err := txn.Commit(); err != nil { + _ = stmt.Close() + return fmt.Errorf("StreamVertices: Txn.Commit: %v", err) + } + + txn, err = g.db.Begin() + if err != nil { + return fmt.Errorf("StreamVertices: Begin New Txn: %v", err) + } + stmt, err = txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamVertices: Prepare New Stmt: %v", err) + } + } + + } + + err = stmt.Close() + if err != nil { + return fmt.Errorf("StreamVertices: Stmt.Close: %v", err) + } + + err = txn.Commit() + if err != nil { + return fmt.Errorf("StreamVertices: Txn.Commit: %v", err) + } + + return nil +} + +func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { + txn, err := g.db.Begin() + if err != nil { + return fmt.Errorf("StreamEdges: Begin Txn: %v", err) + } + + s := fmt.Sprintf( + `INSERT INTO %s (gid, label, "from", "to", data) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (gid) DO UPDATE SET + gid = excluded.gid, + label = excluded.label, + "from" = excluded."from", + "to" = excluded."to", + data = excluded.data;`, + g.e, + ) + stmt, err := txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamEdges: Prepare Stmt: %v", err) + } + + count := 0 + for e := range edges { + js, err := json.Marshal(e.Data) + if err != nil { + return fmt.Errorf("AddEdge: Stmt.Exec: %v", err) + } + _, err = stmt.Exec(e.ID, e.Label, e.From, e.To, js) + if err != nil { + return fmt.Errorf("AddEdge: Stmt.Exec: %v", err) + } + count++ + if count%1000 == 0 { + if err := txn.Commit(); err != nil { + _ = stmt.Close() + return fmt.Errorf("StreamEdges: Txn.Commit: %v", err) + } + + txn, err = g.db.Begin() + if err != nil { + return fmt.Errorf("StreamEdges: Begin New Txn: %v", err) + } + stmt, err = txn.Prepare(s) + if err != nil { + return fmt.Errorf("StreamEdges: Prepare New Stmt: %v", err) + } + } + + } + + err = stmt.Close() + if err != nil { + return fmt.Errorf("StreamEdges: Stmt.Close: %v", err) + } + + err = txn.Commit() + if err != nil { + return fmt.Errorf("StreamEdges: Txn.Commit: %v", err) + } + + return nil +} + func (g *Graph) AddEdge(edges []*gdbi.Edge) error { txn, err := g.db.Begin() if err != 
nil { @@ -161,7 +286,7 @@ func (g *Graph) GetEdge(gid string, load bool) *gdbi.Edge { } func (g *Graph) BulkAdd(stream <-chan *gdbi.GraphElement) error { - return util.StreamBatch(stream, 50, g.graph, g.AddVertex, g.AddEdge) + return util.StreamBatch(stream, 50, g.graph, g.StreamVertices, g.StreamEdges) } func (g *Graph) BulkDel(Data *gdbi.DeleteData) error { diff --git a/util/insert.go b/util/insert.go index 419df0ac..1566f645 100644 --- a/util/insert.go +++ b/util/insert.go @@ -11,39 +11,29 @@ import ( // StreamBatch a stream of inputs and loads them into the graph // This function assumes incoming stream is GraphElemnts from a single graph -func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, vertexAdd func([]*gdbi.Vertex) error, edgeAdd func([]*gdbi.Edge) error) error { - +func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, vertexAdd func(<-chan *gdbi.Vertex, int) error, edgeAdd func(<-chan *gdbi.Edge, int) error) error { var bulkErr *multierror.Error vertCount := 0 edgeCount := 0 - vertexBatchChan := make(chan []*gdbi.Vertex) - edgeBatchChan := make(chan []*gdbi.Edge) wg := &sync.WaitGroup{} - wg.Add(1) + vertexChan := make(chan *gdbi.Vertex, batchSize) + edgeChan := make(chan *gdbi.Edge, batchSize) + + // Start goroutines to process vertices and edges + wg.Add(2) go func() { - for vBatch := range vertexBatchChan { - if len(vBatch) > 0 { - err := vertexAdd(vBatch) - if err != nil { - bulkErr = multierror.Append(bulkErr, err) - } - } + defer wg.Done() + if err := vertexAdd(vertexChan, batchSize); err != nil { + bulkErr = multierror.Append(bulkErr, err) } - wg.Done() }() - wg.Add(1) go func() { - for eBatch := range edgeBatchChan { - if len(eBatch) > 0 { - err := edgeAdd(eBatch) - if err != nil { - bulkErr = multierror.Append(bulkErr, err) - } - } + defer wg.Done() + if err := edgeAdd(edgeChan, batchSize); err != nil { + bulkErr = multierror.Append(bulkErr, err) } - wg.Done() }() vertexBatch := make([]*gdbi.Vertex, 0, batchSize) @@ -55,45 +45,66 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, bulkErr, fmt.Errorf("unexpected graph reference: %s != %s", element.Graph, graph), ) - } else if element.Vertex != nil { - if len(vertexBatch) >= batchSize { - vertexBatchChan <- vertexBatch - vertexBatch = make([]*gdbi.Vertex, 0, batchSize) - } + continue + } + if element.Vertex != nil { vertex := element.Vertex - err := vertex.Validate() - if err != nil { + if err := vertex.Validate(); err != nil { bulkErr = multierror.Append( bulkErr, fmt.Errorf("vertex validation failed: %v", err), ) - } else { - vertexBatch = append(vertexBatch, vertex) - vertCount++ + continue } - } else if element.Edge != nil { - if len(edgeBatch) >= batchSize { - edgeBatchChan <- edgeBatch - edgeBatch = make([]*gdbi.Edge, 0, batchSize) + + vertexBatch = append(vertexBatch, vertex) + vertCount++ + + if len(vertexBatch) >= batchSize { + for _, v := range vertexBatch { + vertexChan <- v + } + vertexBatch = vertexBatch[:0] // Reset batch slice } + } else if element.Edge != nil { edge := element.Edge if edge.ID == "" { edge.ID = UUID() } - err := edge.Validate() - if err != nil { + + if err := edge.Validate(); err != nil { bulkErr = multierror.Append( bulkErr, fmt.Errorf("edge validation failed: %v", err), ) - } else { - edgeBatch = append(edgeBatch, edge) - edgeCount++ + continue + } + + edgeBatch = append(edgeBatch, edge) + edgeCount++ + + if len(edgeBatch) >= batchSize { + for _, e := range edgeBatch { + edgeChan <- e + } + edgeBatch = 
edgeBatch[:0] // Reset batch slice } } } - vertexBatchChan <- vertexBatch - edgeBatchChan <- edgeBatch + + // Send remaining vertices and edges in the batch + for _, v := range vertexBatch { + vertexChan <- v + } + for _, e := range edgeBatch { + edgeChan <- e + } + + // Close channels after all data is sent + close(vertexChan) + close(edgeChan) + + wg.Wait() if vertCount != 0 { log.Debugf("%d vertices streamed to BulkAdd", vertCount) @@ -103,9 +114,5 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, log.Debugf("%d edges streamed to BulkAdd", edgeCount) } - close(edgeBatchChan) - close(vertexBatchChan) - wg.Wait() - return bulkErr.ErrorOrNil() } From 3d80c14ed485dd165c71ee5a1fe706554513b4ce Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Fri, 21 Feb 2025 10:42:07 -0800 Subject: [PATCH 6/7] updates --- mongo/graph.go | 6 +++--- server/api.go | 27 +++++++++++++++++---------- util/insert.go | 14 +++++++++++--- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/mongo/graph.go b/mongo/graph.go index c915ace6..10a0da5f 100644 --- a/mongo/graph.go +++ b/mongo/graph.go @@ -132,7 +132,7 @@ func (mg *Graph) StreamEdges(edgeChan <-chan *gdbi.Edge, batchsize int) error { if err != nil { log.Errorf("StreamEdges error: (%s) %s", docBatch, err) } - docBatch = docBatch[:0] + docBatch = make([]mongo.WriteModel, 0, batchsize) } } if len(docBatch) > 0 { @@ -159,7 +159,7 @@ func (mg *Graph) StreamVertices(vertChan <-chan *gdbi.Vertex, batchsize int) err if err != nil { log.Errorf("StreamVertices error: (%s) %s", docBatch, err) } - docBatch = docBatch[:0] + docBatch = make([]mongo.WriteModel, 0, batchsize) } } if len(docBatch) > 0 { @@ -172,7 +172,7 @@ func (mg *Graph) StreamVertices(vertChan <-chan *gdbi.Vertex, batchsize int) err } func (mg *Graph) BulkAdd(stream <-chan *gdbi.GraphElement) error { - return util.StreamBatch(stream, 100, mg.graph, mg.StreamVertices, mg.StreamEdges) + return util.StreamBatch(stream, 50, mg.graph, mg.StreamVertices, mg.StreamEdges) } func (mg *Graph) BulkDel(Data *gdbi.DeleteData) error { diff --git a/server/api.go b/server/api.go index 952a2ce6..caac9843 100644 --- a/server/api.go +++ b/server/api.go @@ -229,9 +229,9 @@ func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error var populated bool var sch *gripql.Graph out := &graph.GraphSchema{Classes: map[string]*jsonschema.Schema{}, Compiler: nil} - elementStream := make(chan *gdbi.GraphElement, 100) + elementStream := make(chan *gdbi.GraphElement, 50) var retErrs []string - sem := make(chan struct{}, 1000) + sem := make(chan struct{}, 100) for { var err error class, err := stream.Recv() @@ -340,9 +340,16 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { var insertCount int32 var errorCount int32 - elementStream := make(chan *gdbi.GraphElement, 100) + elementStream := make(chan *gdbi.GraphElement, 50) wg := &sync.WaitGroup{} + defer func() { + if elementStream != nil { + close(elementStream) + } + wg.Wait() + }() + for { element, err := stream.Recv() if err == io.EOF { @@ -364,7 +371,10 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { // create a BulkAdd stream per graph // close and switch when a new graph is encountered if element.Graph != graphName { - close(elementStream) + if elementStream != nil { + close(elementStream) + wg.Wait() + } gdb, err := server.getGraphDB(element.Graph) if err != nil { errorCount++ @@ -379,10 +389,10 @@ func (server *GripServer) BulkAdd(stream 
gripql.Edit_BulkAddServer) error { } graphName = element.Graph - elementStream = make(chan *gdbi.GraphElement, 100) + elementStream = make(chan *gdbi.GraphElement, 50) wg.Add(1) - go func() { + go func(graphName string, stream chan *gdbi.GraphElement) { log.WithFields(log.Fields{"graph": element.Graph}).Info("BulkAdd: streaming elements to graph") err := graph.BulkAdd(elementStream) if err != nil { @@ -391,7 +401,7 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { errorCount++ } wg.Done() - }() + }(graphName, elementStream) } if element.Vertex != nil { @@ -421,9 +431,6 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { } } - close(elementStream) - wg.Wait() - return stream.SendAndClose(&gripql.BulkEditResult{InsertCount: insertCount, ErrorCount: errorCount}) } diff --git a/util/insert.go b/util/insert.go index 1566f645..d37cf0c8 100644 --- a/util/insert.go +++ b/util/insert.go @@ -1,12 +1,14 @@ package util import ( + "context" "fmt" "sync" "github.com/bmeg/grip/gdbi" "github.com/bmeg/grip/log" multierror "github.com/hashicorp/go-multierror" + "golang.org/x/sync/semaphore" ) // StreamBatch a stream of inputs and loads them into the graph @@ -20,7 +22,8 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, vertexChan := make(chan *gdbi.Vertex, batchSize) edgeChan := make(chan *gdbi.Edge, batchSize) - // Start goroutines to process vertices and edges + sem := semaphore.NewWeighted(int64(batchSize * 2)) + wg.Add(2) go func() { defer wg.Done() @@ -57,6 +60,7 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, continue } + sem.Acquire(context.Background(), 1) vertexBatch = append(vertexBatch, vertex) vertCount++ @@ -64,7 +68,8 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, for _, v := range vertexBatch { vertexChan <- v } - vertexBatch = vertexBatch[:0] // Reset batch slice + vertexBatch = make([]*gdbi.Vertex, 0, batchSize) + sem.Release(int64(len(vertexBatch))) } } else if element.Edge != nil { edge := element.Edge @@ -80,6 +85,7 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, continue } + sem.Acquire(context.Background(), 1) edgeBatch = append(edgeBatch, edge) edgeCount++ @@ -87,7 +93,8 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, for _, e := range edgeBatch { edgeChan <- e } - edgeBatch = edgeBatch[:0] // Reset batch slice + edgeBatch = make([]*gdbi.Edge, 0, batchSize) + sem.Release(int64(len(edgeBatch))) } } } @@ -105,6 +112,7 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, close(edgeChan) wg.Wait() + sem.Release(int64(len(vertexBatch) + len(edgeBatch))) if vertCount != 0 { log.Debugf("%d vertices streamed to BulkAdd", vertCount) From 4a91bc794513166dfe47249ee5daa5d66debdc64 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Sun, 23 Feb 2025 16:19:04 -0800 Subject: [PATCH 7/7] Make bulk add raw conserve memory --- mongo/graph.go | 2 +- psql/graph.go | 43 +-------- server/api.go | 243 ++++++++++++++++++++++++++++-------------------- sqlite/graph.go | 35 ------- util/insert.go | 6 +- 5 files changed, 150 insertions(+), 179 deletions(-) diff --git a/mongo/graph.go b/mongo/graph.go index 10a0da5f..fb3909dc 100644 --- a/mongo/graph.go +++ b/mongo/graph.go @@ -172,7 +172,7 @@ func (mg *Graph) StreamVertices(vertChan <-chan *gdbi.Vertex, batchsize int) err } func (mg *Graph) BulkAdd(stream <-chan 
*gdbi.GraphElement) error { - return util.StreamBatch(stream, 50, mg.graph, mg.StreamVertices, mg.StreamEdges) + return util.StreamBatch(stream, 100, mg.graph, mg.StreamVertices, mg.StreamEdges) } func (mg *Graph) BulkDel(Data *gdbi.DeleteData) error { diff --git a/psql/graph.go b/psql/graph.go index 7f0db24c..d4ac2040 100644 --- a/psql/graph.go +++ b/psql/graph.go @@ -99,7 +99,6 @@ func (g *Graph) StreamVertices(vertices <-chan *gdbi.Vertex, workers int) error return fmt.Errorf("StreamVertices: Prepare Stmt: %v", err) } - count := 0 for v := range vertices { js, err := json.Marshal(v.Data) if err != nil { @@ -109,24 +108,6 @@ func (g *Graph) StreamVertices(vertices <-chan *gdbi.Vertex, workers int) error if err != nil { return fmt.Errorf("StreamVertices: Stmt.Exec: %v", err) } - count++ - - if count%1000 == 0 { - if err := txn.Commit(); err != nil { - _ = stmt.Close() - return fmt.Errorf("StreamVertices: Txn.Commit: %v", err) - } - - txn, err = g.db.Begin() - if err != nil { - return fmt.Errorf("StreamVertices: Begin New Txn: %v", err) - } - stmt, err = txn.Prepare(s) - if err != nil { - return fmt.Errorf("StreamVertices: Prepare New Stmt: %v", err) - } - } - } err = stmt.Close() @@ -154,8 +135,8 @@ func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { ON CONFLICT (gid) DO UPDATE SET gid = excluded.gid, label = excluded.label, - "from" = excluded.from, - "to" = excluded.to, + "from" = excluded."from", + "to" = excluded."to", data = excluded.data;`, g.e, ) @@ -164,7 +145,6 @@ func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { return fmt.Errorf("StreamEdges: Prepare Stmt: %v", err) } - count := 0 for e := range edges { js, err := json.Marshal(e.Data) if err != nil { @@ -174,28 +154,11 @@ func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { if err != nil { return fmt.Errorf("AddEdge: Stmt.Exec: %v", err) } - count++ - if count%1000 == 0 { - if err := txn.Commit(); err != nil { - _ = stmt.Close() - return fmt.Errorf("StreamEdges: Txn.Commit: %v", err) - } - - txn, err = g.db.Begin() - if err != nil { - return fmt.Errorf("StreamEdges: Begin New Txn: %v", err) - } - stmt, err = txn.Prepare(s) - if err != nil { - return fmt.Errorf("StreamEdges: Prepare New Stmt: %v", err) - } - } - } err = stmt.Close() if err != nil { - return fmt.Errorf("StreamEdges: Stmt.Close: %v", err) + return fmt.Errorf("StreamVertices: Stmt.Close: %v", err) } err = txn.Commit() diff --git a/server/api.go b/server/api.go index caac9843..7c84bd63 100644 --- a/server/api.go +++ b/server/api.go @@ -224,131 +224,168 @@ func (server *GripServer) addEdge(ctx context.Context, elem *gripql.GraphElement } func (server *GripServer) BulkAddRaw(stream gripql.Edit_BulkAddRawServer) error { - var insertCount int32 - wg := &sync.WaitGroup{} - var populated bool - var sch *gripql.Graph - out := &graph.GraphSchema{Classes: map[string]*jsonschema.Schema{}, Compiler: nil} - elementStream := make(chan *gdbi.GraphElement, 50) + elementStream := make(chan *gdbi.GraphElement, 100) var retErrs []string - sem := make(chan struct{}, 100) - for { - var err error - class, err := stream.Recv() - if err == io.EOF { - break + var mu sync.Mutex + var wg sync.WaitGroup + var insertCount int32 = 0 + + // Receive first message + firstClass, err := stream.Recv() + if err != nil { + return err + } + + gdb, err := server.getGraphDB(firstClass.Graph) + if err != nil { + return err + } + + graphtwo, err := gdb.Graph(firstClass.Graph) + if err != nil { + log.WithFields(log.Fields{"error": 
err}).Error("BulkAddRaw: error") + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + err := graphtwo.BulkAdd(elementStream) + if err != nil { + log.WithFields(log.Fields{"graph": firstClass.Graph, "error": err}).Error("BulkAddRaw: error") + mu.Lock() + retErrs = append(retErrs, err.Error()) + mu.Unlock() } + }() + + // Process schema and stream + var populated bool + out := &graph.GraphSchema{Classes: map[string]*jsonschema.Schema{}, Compiler: nil} + processClass := func(class *gripql.RawJson) { if !populated { - sch, err = server.getGraph(class.Graph + "__schema__") + sch, err := server.getGraph(class.Graph + "__schema__") if err != nil { log.Errorf("Error loading schemas: %v", err) + mu.Lock() retErrs = append(retErrs, err.Error()) - break + mu.Unlock() + return } + mu.Lock() out, err = server.LoadSchemas(sch, out) + mu.Unlock() + if err != nil { log.Errorf("Error loading schemas: %v", err) + mu.Lock() retErrs = append(retErrs, err.Error()) - break + mu.Unlock() + return } populated = true } - gdb, err := server.getGraphDB(class.Graph) - if err != nil { - retErrs = append(retErrs, err.Error()) - break - } - - graph, err := gdb.Graph(class.Graph) - if err != nil { - log.WithFields(log.Fields{"error": err}).Error("BulkAddRaw: error") - retErrs = append(retErrs, err.Error()) - continue - } - - wg.Add(1) - go func() { - defer wg.Done() - defer func() { - for ge := range elementStream { - server.streamPool.Put(ge) - } - }() - err := graph.BulkAdd(elementStream) - if err != nil { - log.WithFields(log.Fields{"graph": class.Graph, "error": err}).Error("BulkAddRaw: error") - retErrs = append(retErrs, err.Error()) - } - }() - classData := class.Data.AsMap() - - // to generate grip data, need to know what type the data is. resourceType, ok := classData["resourceType"].(string) if !ok { log.WithFields(log.Fields{"error": fmt.Errorf("row %s does not have required field resourceType", classData)}).Error("BulkAddRaw: streaming error") + mu.Lock() retErrs = append(retErrs, fmt.Sprintf("row %s does not have required field resourceType", classData)) - continue + mu.Unlock() + return } - args := class.ExtraArgs.AsMap() + result, err := out.Generate(resourceType, classData, false, class.ExtraArgs.AsMap()) - result, err := out.Generate(resourceType, classData, false, args) if err != nil { log.WithFields(log.Fields{"error": err}).Errorf("BulkAddRaw: validation error for %s: %s", resourceType, classData) + mu.Lock() retErrs = append(retErrs, err.Error()) - continue + mu.Unlock() + return } for _, element := range result { - ge := server.streamPool.Get().(*gdbi.GraphElement) - ge.Graph = class.Graph if element.Vertex != nil { - ge.Vertex = &gdbi.Vertex{ - ID: element.Vertex.Gid, - Data: element.Vertex.Data.AsMap(), - Label: element.Vertex.Label, - } + elementStream <- &gdbi.GraphElement{ + Vertex: &gdbi.Vertex{ + ID: element.Vertex.Gid, + Data: element.Vertex.Data.AsMap(), + Label: element.Vertex.Label, + }, + Graph: class.Graph} } else { - ge.Edge = &gdbi.Edge{ - ID: element.Edge.Gid, - Label: element.Edge.Label, - From: element.Edge.From, - To: element.Edge.To, - Data: element.Edge.Data.AsMap(), - } + elementStream <- &gdbi.GraphElement{ + Edge: &gdbi.Edge{ + ID: element.Edge.Gid, + Label: element.Edge.Label, + From: element.Edge.From, + To: element.Edge.To, + Data: element.Edge.Data.AsMap(), + }, + Graph: class.Graph} } - sem <- struct{}{} - elementStream <- ge + mu.Lock() insertCount++ - <-sem + mu.Unlock() } - } - close(elementStream) + + processClass(firstClass) + + wg.Add(1) + go func() { + 
defer wg.Done() + defer close(elementStream) + + for { + class, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + mu.Lock() + retErrs = append(retErrs, err.Error()) + mu.Unlock() + break + } + processClass(class) + } + }() + wg.Wait() - close(sem) + return stream.SendAndClose(&gripql.BulkJsonEditResult{InsertCount: insertCount, Errors: retErrs}) } -// BulkAdd a stream of inputs and loads them into the graph func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { - var graphName string var insertCount int32 var errorCount int32 + var mu sync.Mutex - elementStream := make(chan *gdbi.GraphElement, 50) wg := &sync.WaitGroup{} + currentGraph := "" + var elementStream chan *gdbi.GraphElement - defer func() { - if elementStream != nil { - close(elementStream) - } - wg.Wait() - }() + // Function to start a new BulkAdd goroutine for a graph + startBulkAdd := func(graphName string, gdb gdbi.GraphInterface) chan *gdbi.GraphElement { + newStream := make(chan *gdbi.GraphElement, 100) + wg.Add(1) + go func(g gdbi.GraphInterface, stream chan *gdbi.GraphElement) { + defer wg.Done() + log.WithFields(log.Fields{"graph": graphName}).Info("BulkAdd: streaming elements to graph") + if err := g.BulkAdd(stream); err != nil { + log.WithFields(log.Fields{"graph": graphName, "error": err}).Error("BulkAdd: error") + mu.Lock() + errorCount++ + mu.Unlock() + } + }(gdb, newStream) + return newStream + } for { element, err := stream.Recv() @@ -357,57 +394,54 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { } if err != nil { log.WithFields(log.Fields{"error": err}).Error("BulkAdd: streaming error") + mu.Lock() errorCount++ + mu.Unlock() break } if isSchema(element.Graph) { err := "cannot add element to schema graph" log.WithFields(log.Fields{"error": err}).Error("BulkAdd: error") + mu.Lock() errorCount++ + mu.Unlock() continue } - // create a BulkAdd stream per graph - // close and switch when a new graph is encountered - if element.Graph != graphName { + // Switch graphs if needed + if element.Graph != currentGraph { if elementStream != nil { close(elementStream) - wg.Wait() } + gdb, err := server.getGraphDB(element.Graph) if err != nil { + mu.Lock() errorCount++ + mu.Unlock() continue } graph, err := gdb.Graph(element.Graph) if err != nil { log.WithFields(log.Fields{"error": err}).Error("BulkAdd: error") + mu.Lock() errorCount++ + mu.Unlock() continue } - graphName = element.Graph - elementStream = make(chan *gdbi.GraphElement, 50) - - wg.Add(1) - go func(graphName string, stream chan *gdbi.GraphElement) { - log.WithFields(log.Fields{"graph": element.Graph}).Info("BulkAdd: streaming elements to graph") - err := graph.BulkAdd(elementStream) - if err != nil { - log.WithFields(log.Fields{"graph": element.Graph, "error": err}).Error("BulkAdd: error") - // not a good representation of the true number of errors - errorCount++ - } - wg.Done() - }(graphName, elementStream) + currentGraph = element.Graph + elementStream = startBulkAdd(currentGraph, graph) } + // Process vertices if element.Vertex != nil { - err := element.Vertex.Validate() - if err != nil { + if err := element.Vertex.Validate(); err != nil { + mu.Lock() errorCount++ + mu.Unlock() log.WithFields(log.Fields{"graph": element.Graph, "error": err}).Errorf("BulkAdd: vertex validation failed for vertex: %#v", element.Vertex) } else { insertCount++ @@ -415,22 +449,29 @@ func (server *GripServer) BulkAdd(stream gripql.Edit_BulkAddServer) error { } } + // Process edges if element.Edge != nil { if 
element.Edge.Gid == "" { element.Edge.Gid = util.UUID() } - err := element.Edge.Validate() - if err != nil { + if err := element.Edge.Validate(); err != nil { + mu.Lock() errorCount++ + mu.Unlock() log.WithFields(log.Fields{"graph": element.Graph, "error": err}).Errorf("BulkAdd: edge validation failed for edge: %#v", element.Edge) } else { insertCount++ elementStream <- gdbi.NewGraphElement(element) - } } } + if elementStream != nil { + close(elementStream) + } + + wg.Wait() + return stream.SendAndClose(&gripql.BulkEditResult{InsertCount: insertCount, ErrorCount: errorCount}) } diff --git a/sqlite/graph.go b/sqlite/graph.go index 9b780606..5d3a7d40 100644 --- a/sqlite/graph.go +++ b/sqlite/graph.go @@ -96,7 +96,6 @@ func (g *Graph) StreamVertices(vertices <-chan *gdbi.Vertex, workers int) error return fmt.Errorf("StreamVertices: Prepare Stmt: %v", err) } - count := 0 for v := range vertices { js, err := json.Marshal(v.Data) if err != nil { @@ -106,23 +105,6 @@ func (g *Graph) StreamVertices(vertices <-chan *gdbi.Vertex, workers int) error if err != nil { return fmt.Errorf("StreamVertices: Stmt.Exec: %v", err) } - count++ - - if count%1000 == 0 { - if err := txn.Commit(); err != nil { - _ = stmt.Close() - return fmt.Errorf("StreamVertices: Txn.Commit: %v", err) - } - - txn, err = g.db.Begin() - if err != nil { - return fmt.Errorf("StreamVertices: Begin New Txn: %v", err) - } - stmt, err = txn.Prepare(s) - if err != nil { - return fmt.Errorf("StreamVertices: Prepare New Stmt: %v", err) - } - } } @@ -160,7 +142,6 @@ func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { return fmt.Errorf("StreamEdges: Prepare Stmt: %v", err) } - count := 0 for e := range edges { js, err := json.Marshal(e.Data) if err != nil { @@ -170,22 +151,6 @@ func (g *Graph) StreamEdges(edges <-chan *gdbi.Edge, workers int) error { if err != nil { return fmt.Errorf("AddEdge: Stmt.Exec: %v", err) } - count++ - if count%1000 == 0 { - if err := txn.Commit(); err != nil { - _ = stmt.Close() - return fmt.Errorf("StreamEdges: Txn.Commit: %v", err) - } - - txn, err = g.db.Begin() - if err != nil { - return fmt.Errorf("StreamEdges: Begin New Txn: %v", err) - } - stmt, err = txn.Prepare(s) - if err != nil { - return fmt.Errorf("StreamEdges: Prepare New Stmt: %v", err) - } - } } diff --git a/util/insert.go b/util/insert.go index d37cf0c8..b923b8a9 100644 --- a/util/insert.go +++ b/util/insert.go @@ -65,11 +65,12 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, vertCount++ if len(vertexBatch) >= batchSize { + batchSizeToRelease := len(vertexBatch) for _, v := range vertexBatch { vertexChan <- v } vertexBatch = make([]*gdbi.Vertex, 0, batchSize) - sem.Release(int64(len(vertexBatch))) + sem.Release(int64(batchSizeToRelease)) } } else if element.Edge != nil { edge := element.Edge @@ -90,11 +91,12 @@ func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, edgeCount++ if len(edgeBatch) >= batchSize { + batchSizeToRelease := len(edgeBatch) for _, e := range edgeBatch { edgeChan <- e } edgeBatch = make([]*gdbi.Edge, 0, batchSize) - sem.Release(int64(len(edgeBatch))) + sem.Release(int64(batchSizeToRelease)) } } }
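
For reference, the shape these patches converge on in util.StreamBatch and the driver StreamVertices/StreamEdges writers: the producer routes a mixed element stream onto buffered vertex and edge channels, each writer drains its own channel, and the channels are closed only after the last send so the writers' range loops terminate before wg.Wait returns. The sketch below is a minimal, self-contained illustration of that producer/consumer pattern; the vertex and edge types, function names, and error handling here are illustrative stand-ins, not the grip/gdbi API.

package main

import (
	"fmt"
	"sync"
)

// Illustrative element types; the real code uses gdbi.Vertex and gdbi.Edge.
type vertex struct{ ID, Label string }
type edge struct{ ID, From, To, Label string }

// streamBatch fans a mixed element stream out to two consumers over
// buffered channels; the consumers drain the channels themselves, the
// producer only routes elements and then closes both channels.
func streamBatch(elements []interface{}, bufSize int,
	addVertices func(<-chan vertex) error,
	addEdges func(<-chan edge) error) error {

	vch := make(chan vertex, bufSize)
	ech := make(chan edge, bufSize)

	var wg sync.WaitGroup
	var mu sync.Mutex
	var errs []error

	record := func(err error) {
		if err != nil {
			mu.Lock()
			errs = append(errs, err)
			mu.Unlock()
		}
	}

	wg.Add(2)
	go func() { defer wg.Done(); record(addVertices(vch)) }()
	go func() { defer wg.Done(); record(addEdges(ech)) }()

	// Producer: route each element to the matching consumer channel.
	for _, e := range elements {
		switch v := e.(type) {
		case vertex:
			vch <- v
		case edge:
			ech <- v
		}
	}

	// Close only after all sends so the consumers' range loops end,
	// then wait for both writers before reporting any errors.
	close(vch)
	close(ech)
	wg.Wait()

	if len(errs) > 0 {
		return fmt.Errorf("bulk add finished with %d error(s)", len(errs))
	}
	return nil
}

func main() {
	elems := []interface{}{
		vertex{ID: "v1", Label: "Person"},
		edge{ID: "e1", From: "v1", To: "v2", Label: "knows"},
		vertex{ID: "v2", Label: "Person"},
	}
	err := streamBatch(elems, 10,
		func(vs <-chan vertex) error {
			for v := range vs {
				fmt.Println("write vertex", v.ID)
			}
			return nil
		},
		func(es <-chan edge) error {
			for e := range es {
				fmt.Println("write edge", e.ID)
			}
			return nil
		})
	fmt.Println("done, err =", err)
}

Closing both channels only after the producer loop finishes, and calling wg.Wait only after that, is the same ordering the patched StreamBatch relies on to shut the writer goroutines down cleanly.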